mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor
This commit is contained in:
commit
c1c29deaa9
|
@ -1,2 +1,28 @@
|
||||||
Dockerfile
|
Dockerfile
|
||||||
**/publish.py
|
**/publish.py
|
||||||
|
my
|
||||||
|
.git
|
||||||
|
.refresh
|
||||||
|
__pycache__
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
.vscode/
|
||||||
|
__res/
|
||||||
|
perf.data
|
||||||
|
perf.data.old
|
||||||
|
*.swp
|
||||||
|
*.ipynb
|
||||||
|
*.pdf
|
||||||
|
*.zip
|
||||||
|
*.tgz
|
||||||
|
test.py
|
||||||
|
extern/mkl/mkldnn_lnx*/*
|
||||||
|
data/
|
||||||
|
build/
|
||||||
|
venv/
|
||||||
|
*.md
|
||||||
|
!*.src.md
|
||||||
|
!README.md
|
||||||
|
!README.cn.md
|
||||||
|
python/jittor.egg-info
|
||||||
|
dist/
|
||||||
|
!doc/source/*
|
||||||
|
|
|
@ -38,12 +38,14 @@ RUN pip3 install matplotlib
|
||||||
|
|
||||||
RUN apt install openmpi-bin openmpi-common libopenmpi-dev -y
|
RUN apt install openmpi-bin openmpi-common libopenmpi-dev -y
|
||||||
|
|
||||||
|
RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
||||||
|
|
||||||
|
RUN pip3 uninstall jittor -y
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN pip3 install . --timeout 100
|
RUN pip3 install . --timeout 100
|
||||||
|
|
||||||
RUN python3.7 -m jittor.test.test_example
|
RUN python3.7 -m jittor.test.test_example
|
||||||
|
|
||||||
RUN rm -rf ~/.cache/jittor/default
|
|
||||||
|
|
||||||
CMD python3.7 -m jittor.notebook --allow-root --ip=0.0.0.0
|
CMD python3.7 -m jittor.notebook --allow-root --ip=0.0.0.0
|
|
@ -24,7 +24,7 @@ CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdim
|
||||||
: x(x), offsets(offsets), op(op), keepdims(keepdims) {
|
: x(x), offsets(offsets), op(op), keepdims(keepdims) {
|
||||||
flags.set(NodeFlags::_cpu, 0);
|
flags.set(NodeFlags::_cpu, 0);
|
||||||
flags.set(NodeFlags::_cuda, 1);
|
flags.set(NodeFlags::_cuda, 1);
|
||||||
ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32);
|
ASSERT(offsets->dtype()==ns_int32);
|
||||||
y = create_output(nullptr, ns_int32);
|
y = create_output(nullptr, ns_int32);
|
||||||
y_key = create_output(nullptr, x->dtype());
|
y_key = create_output(nullptr, x->dtype());
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending,
|
||||||
: x(x), indexes(indexes), offsets(offsets), descending(descending) {
|
: x(x), indexes(indexes), offsets(offsets), descending(descending) {
|
||||||
flags.set(NodeFlags::_cpu, 0);
|
flags.set(NodeFlags::_cpu, 0);
|
||||||
flags.set(NodeFlags::_cuda, 1);
|
flags.set(NodeFlags::_cuda, 1);
|
||||||
ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32);
|
ASSERT(offsets->dtype()==ns_int32);
|
||||||
y = create_output(nullptr, dtype);
|
y = create_output(nullptr, dtype);
|
||||||
y_key = create_output(nullptr, x->dtype());
|
y_key = create_output(nullptr, x->dtype());
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ namespace jittor {
|
||||||
|
|
||||||
struct CurandRandomOp : Op {
|
struct CurandRandomOp : Op {
|
||||||
Var* output;
|
Var* output;
|
||||||
CurandRandomOp(NanoVector shape, NanoString dtype=ns_float);
|
CurandRandomOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||||
|
|
||||||
const char* name() const override { return "curand_random"; }
|
const char* name() const override { return "curand_random"; }
|
||||||
DECLARE_jit_run;
|
DECLARE_jit_run;
|
||||||
|
|
|
@ -101,6 +101,17 @@ const char *_cudaGetErrorEnum(NppStatus error);
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void peek(T result, char const *const func, const char *const file,
|
||||||
|
int const line) {
|
||||||
|
if (result) {
|
||||||
|
// DEVICE_RESET
|
||||||
|
LOGe << "Peek CUDA error at" << file >> ":" >> line << " code="
|
||||||
|
>> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
|
||||||
|
<< func;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void check(T result, char const *const func, const char *const file,
|
void check(T result, char const *const func, const char *const file,
|
||||||
int const line) {
|
int const line) {
|
||||||
|
@ -116,6 +127,7 @@ void check(T result, char const *const func, const char *const file,
|
||||||
// This will output the proper CUDA error strings in the event
|
// This will output the proper CUDA error strings in the event
|
||||||
// that a CUDA host call returns an error
|
// that a CUDA host call returns an error
|
||||||
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
|
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
|
||||||
|
#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__)
|
||||||
|
|
||||||
// This will output the proper error string when calling cudaGetLastError
|
// This will output the proper error string when calling cudaGetLastError
|
||||||
#define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__)
|
#define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__)
|
||||||
|
|
|
@ -83,7 +83,7 @@ mpi_initer() {
|
||||||
MPI_CHECK(MPI_Init(NULL, NULL));
|
MPI_CHECK(MPI_Init(NULL, NULL));
|
||||||
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
||||||
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
||||||
|
|
||||||
//calculating localRank based on hostname which is used in selecting a GPU
|
//calculating localRank based on hostname which is used in selecting a GPU
|
||||||
uint64_t hostHashs[mpi_world_rank];
|
uint64_t hostHashs[mpi_world_rank];
|
||||||
char hostname[1024];
|
char hostname[1024];
|
||||||
|
|
|
@ -33,7 +33,7 @@ namespace jittor {
|
||||||
|
|
||||||
struct CustomOp : Op {
|
struct CustomOp : Op {
|
||||||
Var* output;
|
Var* output;
|
||||||
CustomOp(NanoVector shape, NanoString dtype=ns_float);
|
CustomOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||||
|
|
||||||
const char* name() const override { return "custom"; }
|
const char* name() const override { return "custom"; }
|
||||||
DECLARE_jit_run;
|
DECLARE_jit_run;
|
||||||
|
|
|
@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace
|
||||||
|
|
||||||
def unsqueeze(x, dim):
|
def unsqueeze(x, dim):
|
||||||
shape = list(x.shape)
|
shape = list(x.shape)
|
||||||
|
if dim < 0: dim += len(shape) + 1
|
||||||
assert dim <= len(shape)
|
assert dim <= len(shape)
|
||||||
return x.reshape(shape[:dim] + [1] + shape[dim:])
|
return x.reshape(shape[:dim] + [1] + shape[dim:])
|
||||||
Var.unsqueeze = unsqueeze
|
Var.unsqueeze = unsqueeze
|
||||||
|
@ -304,11 +305,11 @@ Var.masked_fill = masked_fill
|
||||||
def sqr(x): return x*x
|
def sqr(x): return x*x
|
||||||
Var.sqr = sqr
|
Var.sqr = sqr
|
||||||
|
|
||||||
def argmax(x, dim:int, keepdims:bool=False):
|
def argmax(x, dim, keepdims:bool=False):
|
||||||
return x.arg_reduce("max", dim, keepdims)
|
return x.arg_reduce("max", dim, keepdims)
|
||||||
Var.argmax = argmax
|
Var.argmax = argmax
|
||||||
|
|
||||||
def argmin(x, dim:int, keepdims:bool=False):
|
def argmin(x, dim, keepdims:bool=False):
|
||||||
return x.arg_reduce("min", dim, keepdims)
|
return x.arg_reduce("min", dim, keepdims)
|
||||||
Var.argmin = argmin
|
Var.argmin = argmin
|
||||||
|
|
||||||
|
@ -321,13 +322,54 @@ def attrs(var):
|
||||||
}
|
}
|
||||||
Var.attrs = attrs
|
Var.attrs = attrs
|
||||||
|
|
||||||
def fetch(vars, func, *args, **kw):
|
def fetch(*args):
|
||||||
core.fetch(vars, lambda *results: func(*results, *args, **kw))
|
''' Async fetch vars with function closure.
|
||||||
|
|
||||||
|
Example 1::
|
||||||
|
|
||||||
def fetch_var(var, func, *args, **kw):
|
for img,label in enumerate(your_dataset):
|
||||||
core.fetch([var], lambda a: func(a, *args, **kw))
|
pred = your_model(img)
|
||||||
Var.fetch = fetch_var
|
loss = critic(pred, label)
|
||||||
del fetch_var
|
acc = accuracy(pred, label)
|
||||||
|
jt.fetch(acc, loss,
|
||||||
|
lambda acc, loss:
|
||||||
|
print(f"loss:{loss} acc:{acc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
Example 2::
|
||||||
|
|
||||||
|
for i,(img,label) in enumerate(your_dataset):
|
||||||
|
pred = your_model(img)
|
||||||
|
loss = critic(pred, label)
|
||||||
|
acc = accuracy(pred, label)
|
||||||
|
# variable i will be bind into function closure
|
||||||
|
jt.fetch(i, acc, loss,
|
||||||
|
lambda i, acc, loss:
|
||||||
|
print(f"#{i}, loss:{loss} acc:{acc}"
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
assert len(args)>=1
|
||||||
|
func = args[-1]
|
||||||
|
assert callable(func)
|
||||||
|
args = list(args[:-1])
|
||||||
|
if len(args)>0 and isinstance(args[0], Sequence) \
|
||||||
|
and len(args[0])>=1 and isinstance(args[0][0], Var):
|
||||||
|
raise TypeError("jt.Var should not inside a list or tuple.")
|
||||||
|
|
||||||
|
var_map = []
|
||||||
|
variables = []
|
||||||
|
for i, v in enumerate(args):
|
||||||
|
if isinstance(v, Var):
|
||||||
|
variables.append(v)
|
||||||
|
var_map.append(i)
|
||||||
|
args[i] = None
|
||||||
|
def callback(*results):
|
||||||
|
for i,v in enumerate(results):
|
||||||
|
args[var_map[i]] = v
|
||||||
|
func(*args)
|
||||||
|
core.ops.fetch(variables, callback)
|
||||||
|
|
||||||
|
Var.fetch = fetch
|
||||||
|
|
||||||
def display_memory_info():
|
def display_memory_info():
|
||||||
import inspect, os
|
import inspect, os
|
||||||
|
@ -439,11 +481,11 @@ class Module:
|
||||||
end = 0
|
end = 0
|
||||||
for k in key_:
|
for k in key_:
|
||||||
if isinstance(v, nn.Sequential):
|
if isinstance(v, nn.Sequential):
|
||||||
if np.int(k) >= len(v.layers):
|
if ori_int(k) >= len(v.layers):
|
||||||
end = 1
|
end = 1
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
v = v[np.int(k)]
|
v = v[ori_int(k)]
|
||||||
else:
|
else:
|
||||||
if hasattr(v, k):
|
if hasattr(v, k):
|
||||||
v = getattr(v, k)
|
v = getattr(v, k)
|
||||||
|
@ -574,12 +616,23 @@ def jittor_exit():
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
core.sync_all(True)
|
core.sync_all(True)
|
||||||
|
core.cleanup()
|
||||||
atexit.register(jittor_exit)
|
atexit.register(jittor_exit)
|
||||||
|
|
||||||
Var.__str__ = lambda x: str(x.data)
|
Var.__str__ = lambda x: str(x.data)
|
||||||
Var.__repr__ = lambda x: str(x.data)
|
Var.__repr__ = lambda x: str(x.data)
|
||||||
Var.peek = lambda x: f"{x.dtype}{x.shape}"
|
Var.peek = lambda x: f"{x.dtype}{x.shape}"
|
||||||
|
|
||||||
|
|
||||||
|
ori_int = int
|
||||||
|
|
||||||
|
int = int32
|
||||||
|
Var.int = Var.int32
|
||||||
|
float = float32
|
||||||
|
Var.float = Var.float32
|
||||||
|
double = float64
|
||||||
|
Var.double = Var.float64
|
||||||
|
|
||||||
from . import nn
|
from . import nn
|
||||||
from .nn import matmul
|
from .nn import matmul
|
||||||
from . import contrib
|
from . import contrib
|
||||||
|
|
|
@ -345,7 +345,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
||||||
with open(os.path.join(jittor_path, header), encoding='utf8') as f:
|
with open(os.path.join(jittor_path, header), encoding='utf8') as f:
|
||||||
src = f.read()
|
src = f.read()
|
||||||
# XxxXxxOp(args)
|
# XxxXxxOp(args)
|
||||||
res = re.findall(pybind_attrs_reg + '('+name2+"\\([^\\n]*\\))", src, re.S)
|
res = re.findall(pybind_attrs_reg + '[^~]('+name2+"\\([^\\n]*\\))", src, re.S)
|
||||||
assert len(res) >= 1, "Wrong op args in " + header
|
assert len(res) >= 1, "Wrong op args in " + header
|
||||||
# registe op
|
# registe op
|
||||||
cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
|
cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
|
||||||
|
@ -908,14 +908,14 @@ with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
||||||
f.write(jit_src)
|
f.write(jit_src)
|
||||||
cc_flags += f' -I{cache_path} '
|
cc_flags += f' -I{cache_path} '
|
||||||
# gen pyjt
|
# gen pyjt
|
||||||
pyjt_compiler.compile(cache_path, jittor_path)
|
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
|
||||||
|
|
||||||
# initialize order:
|
# initialize order:
|
||||||
# 1. registers
|
# 1. registers
|
||||||
# 2. generate source
|
# 2. generate source
|
||||||
# 3. op_utils
|
# 3. op_utils
|
||||||
# 4. other
|
# 4. other
|
||||||
files2 = run_cmd(f'find "{os.path.join(cache_path, "gen")}" | grep "cc$"').splitlines()
|
files2 = pyjt_gen_src
|
||||||
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
|
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
|
||||||
at_beginning = [
|
at_beginning = [
|
||||||
"src/ops/op_utils.cc",
|
"src/ops/op_utils.cc",
|
||||||
|
|
|
@ -258,7 +258,7 @@ def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs
|
||||||
|
|
||||||
LOG.vvv("gen err from func_head", func_head)
|
LOG.vvv("gen err from func_head", func_head)
|
||||||
args = func_head[1:].split(")")[0].split(",")
|
args = func_head[1:].split(")")[0].split(",")
|
||||||
error_code = f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\""
|
error_code = f" << \"Wrong inputs arguments, Please refer to examples({help_cmd}).\""
|
||||||
error_code += r' << "\n\nTypes of your inputs are:\n"'
|
error_code += r' << "\n\nTypes of your inputs are:\n"'
|
||||||
for arg in args:
|
for arg in args:
|
||||||
arg = arg.strip()
|
arg = arg.strip()
|
||||||
|
@ -849,6 +849,7 @@ def compile(cache_path, jittor_path):
|
||||||
headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
|
headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
|
||||||
[ os.path.join(cache_path, h) for h in headers2 ]
|
[ os.path.join(cache_path, h) for h in headers2 ]
|
||||||
basenames = []
|
basenames = []
|
||||||
|
pyjt_names = []
|
||||||
for h in headers:
|
for h in headers:
|
||||||
with open(h, 'r') as f:
|
with open(h, 'r') as f:
|
||||||
src = f.read()
|
src = f.read()
|
||||||
|
@ -866,6 +867,7 @@ def compile(cache_path, jittor_path):
|
||||||
if not check: continue
|
if not check: continue
|
||||||
|
|
||||||
basenames.append(basename)
|
basenames.append(basename)
|
||||||
|
pyjt_names.append(fname)
|
||||||
|
|
||||||
code = f"""
|
code = f"""
|
||||||
#include "pyjt/numpy.h"
|
#include "pyjt/numpy.h"
|
||||||
|
@ -888,3 +890,5 @@ def compile(cache_path, jittor_path):
|
||||||
LOG.vvvv(code)
|
LOG.vvvv(code)
|
||||||
with open(fname, "w") as f:
|
with open(fname, "w") as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
|
pyjt_names.append(fname)
|
||||||
|
return pyjt_names
|
||||||
|
|
|
@ -60,6 +60,7 @@ class TestArray(unittest.TestCase):
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
x = jt.array(im)
|
x = jt.array(im)
|
||||||
b = net(x)
|
b = net(x)
|
||||||
|
b.fetch(lambda b: None)
|
||||||
b.sync()
|
b.sync()
|
||||||
jt.sync(device_sync=True)
|
jt.sync(device_sync=True)
|
||||||
|
|
||||||
|
@ -70,6 +71,7 @@ class TestArray(unittest.TestCase):
|
||||||
x = jt.array(im)
|
x = jt.array(im)
|
||||||
b = net(x)
|
b = net(x)
|
||||||
b.fetch(lambda b: results.append(b))
|
b.fetch(lambda b: results.append(b))
|
||||||
|
b.sync()
|
||||||
# del c
|
# del c
|
||||||
jt.sync(device_sync=True)
|
jt.sync(device_sync=True)
|
||||||
t2 = time.time() - time_start
|
t2 = time.time() - time_start
|
||||||
|
@ -111,6 +113,12 @@ class TestArray(unittest.TestCase):
|
||||||
""")
|
""")
|
||||||
assert (b.data==[2,8,18]).all()
|
assert (b.data==[2,8,18]).all()
|
||||||
|
|
||||||
|
def test_not_c_style(self):
|
||||||
|
a = np.array([1,2,3])
|
||||||
|
b = a[::-1]
|
||||||
|
x = jt.array(b)
|
||||||
|
x = x + b
|
||||||
|
assert (x.data == [6,4,2]).all()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ def expect_error(func):
|
||||||
|
|
||||||
class TestCore(unittest.TestCase):
|
class TestCore(unittest.TestCase):
|
||||||
def test_number_of_hold_vars(self):
|
def test_number_of_hold_vars(self):
|
||||||
assert jt.random([1,2,3]).peek() == "float[1,2,3,]"
|
assert jt.random([1,2,3]).peek() == "float32[1,2,3,]"
|
||||||
assert jt.core.number_of_hold_vars() == 0
|
assert jt.core.number_of_hold_vars() == 0
|
||||||
x = jt.random([1,2,3])
|
x = jt.random([1,2,3])
|
||||||
assert jt.core.number_of_hold_vars() == 1
|
assert jt.core.number_of_hold_vars() == 1
|
||||||
|
|
|
@ -16,7 +16,7 @@ namespace jittor {
|
||||||
|
|
||||||
struct CustomOp : Op {
|
struct CustomOp : Op {
|
||||||
Var* output;
|
Var* output;
|
||||||
CustomOp(NanoVector shape, NanoString dtype=ns_float);
|
CustomOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||||
|
|
||||||
const char* name() const override { return "custom"; }
|
const char* name() const override { return "custom"; }
|
||||||
DECLARE_jit_run;
|
DECLARE_jit_run;
|
||||||
|
@ -75,7 +75,7 @@ class TestCustomOp(unittest.TestCase):
|
||||||
my_op = jt.compile_custom_op("""
|
my_op = jt.compile_custom_op("""
|
||||||
struct MyOp : Op {
|
struct MyOp : Op {
|
||||||
Var* output;
|
Var* output;
|
||||||
MyOp(NanoVector shape, NanoString dtype=ns_float);
|
MyOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||||
|
|
||||||
const char* name() const override { return "my"; }
|
const char* name() const override { return "my"; }
|
||||||
DECLARE_jit_run;
|
DECLARE_jit_run;
|
||||||
|
|
|
@ -13,7 +13,10 @@ class TestFetcher(unittest.TestCase):
|
||||||
a = jt.array([1,2,3])
|
a = jt.array([1,2,3])
|
||||||
a = a*2
|
a = a*2
|
||||||
v = []
|
v = []
|
||||||
jt.fetch([a], lambda a: v.append(a))
|
jt.fetch(a, lambda a: v.append(a))
|
||||||
|
jt.fetch(1, 2, 3, a,
|
||||||
|
lambda x, y, z, a: self.assertTrue(x==1 and y==2 and z==3 and isinstance(a, np.ndarray))
|
||||||
|
)
|
||||||
jt.sync_all(True)
|
jt.sync_all(True)
|
||||||
assert len(v)==1 and (v[0]==[2,4,6]).all()
|
assert len(v)==1 and (v[0]==[2,4,6]).all()
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from .test_core import expect_error
|
||||||
import os
|
import os
|
||||||
|
|
||||||
mid = 0
|
mid = 0
|
||||||
if os.uname()[1] == "jittor-ce":
|
if "jittor" in os.uname()[1]:
|
||||||
mid = 1
|
mid = 1
|
||||||
|
|
||||||
class TestNanoString(unittest.TestCase):
|
class TestNanoString(unittest.TestCase):
|
||||||
|
@ -27,7 +27,8 @@ class TestNanoString(unittest.TestCase):
|
||||||
assert t < [1.5e-7, 1.7e-7][mid], t
|
assert t < [1.5e-7, 1.7e-7][mid], t
|
||||||
|
|
||||||
assert (jt.hash("asdasd") == 4152566416)
|
assert (jt.hash("asdasd") == 4152566416)
|
||||||
assert str(jt.NanoString("float"))=="float"
|
assert str(jt.NanoString("float"))=="float32"
|
||||||
|
assert jt.NanoString("float")=="float32"
|
||||||
# pybind11: 7
|
# pybind11: 7
|
||||||
# Tuple call: 1.3
|
# Tuple call: 1.3
|
||||||
# fast call (with or with not): 0.9
|
# fast call (with or with not): 0.9
|
||||||
|
@ -38,14 +39,14 @@ class TestNanoString(unittest.TestCase):
|
||||||
|
|
||||||
def test_type(self):
|
def test_type(self):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
assert str(jt.NanoString(float)) == "float"
|
assert str(jt.NanoString(float)) == "float32"
|
||||||
assert str(jt.NanoString(np.float)) == "float"
|
assert str(jt.NanoString(np.float)) == "float32"
|
||||||
assert str(jt.NanoString(np.float32)) == "float32"
|
assert str(jt.NanoString(np.float32)) == "float32"
|
||||||
assert str(jt.NanoString(np.float64)) == "float64"
|
assert str(jt.NanoString(np.float64)) == "float64"
|
||||||
assert str(jt.NanoString(np.int8)) == "int8"
|
assert str(jt.NanoString(np.int8)) == "int8"
|
||||||
assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64"
|
assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64"
|
||||||
|
|
||||||
assert str(jt.NanoString(jt.float)) == "float"
|
assert str(jt.NanoString(jt.float)) == "float32"
|
||||||
assert str(jt.NanoString(jt.float32)) == "float32"
|
assert str(jt.NanoString(jt.float32)) == "float32"
|
||||||
assert str(jt.NanoString(jt.float64)) == "float64"
|
assert str(jt.NanoString(jt.float64)) == "float64"
|
||||||
assert str(jt.NanoString(jt.int8)) == "int8"
|
assert str(jt.NanoString(jt.int8)) == "int8"
|
||||||
|
|
|
@ -99,6 +99,7 @@ class TestResizeAndCrop(unittest.TestCase):
|
||||||
test_case(20, [1024, 1024], [1.2, 1.8][mid])
|
test_case(20, [1024, 1024], [1.2, 1.8][mid])
|
||||||
test_case(20, [1024, 666], [0.8,1.0][mid])
|
test_case(20, [1024, 666], [0.8,1.0][mid])
|
||||||
|
|
||||||
|
@unittest.skipIf(torch is None, "no torch found")
|
||||||
def test_resize(self):
|
def test_resize(self):
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
|
x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
|
||||||
|
@ -108,11 +109,13 @@ class TestResizeAndCrop(unittest.TestCase):
|
||||||
jnn.Resize((r_size, r_size), 'bilinear', align_corners),
|
jnn.Resize((r_size, r_size), 'bilinear', align_corners),
|
||||||
lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
|
lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
|
||||||
|
|
||||||
|
@unittest.skipIf(torch is None, "no torch found")
|
||||||
def test_upsample(self):
|
def test_upsample(self):
|
||||||
arr = np.random.randn(2,3,224,224)
|
arr = np.random.randn(2,3,224,224)
|
||||||
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
||||||
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
||||||
|
|
||||||
|
@unittest.skipIf(torch is None, "no torch found")
|
||||||
def test_pixelshuffle(self):
|
def test_pixelshuffle(self):
|
||||||
arr = np.random.randn(2,4,224,224)
|
arr = np.random.randn(2,4,224,224)
|
||||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
||||||
|
|
|
@ -64,16 +64,16 @@ class TestResnet(unittest.TestCase):
|
||||||
SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
|
SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
|
||||||
|
|
||||||
for batch_idx, (data, target) in enumerate(self.train_loader):
|
for batch_idx, (data, target) in enumerate(self.train_loader):
|
||||||
output = mnist_net(data)
|
|
||||||
loss = nn.cross_entropy_loss(output, target)
|
|
||||||
|
|
||||||
# train step
|
# train step
|
||||||
with jt.log_capture_scope(
|
with jt.log_capture_scope(
|
||||||
log_silent=1,
|
log_silent=1,
|
||||||
log_v=1, log_vprefix="op.cc=100,exe=10",
|
log_v=1, log_vprefix="op.cc=100,exe=10",
|
||||||
) as logs:
|
) as logs:
|
||||||
|
output = mnist_net(data)
|
||||||
|
loss = nn.cross_entropy_loss(output, target)
|
||||||
SGD.step(loss)
|
SGD.step(loss)
|
||||||
def callback(loss, output, target, batch_idx):
|
def callback(batch_idx, loss, output, target):
|
||||||
# print train info
|
# print train info
|
||||||
global prev
|
global prev
|
||||||
pred = np.argmax(output, axis=1)
|
pred = np.argmax(output, axis=1)
|
||||||
|
@ -83,13 +83,13 @@ class TestResnet(unittest.TestCase):
|
||||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
|
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
|
||||||
.format(0, batch_idx, 600,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
|
.format(0, batch_idx, 600,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
|
||||||
# prev = time.time()
|
# prev = time.time()
|
||||||
jt.fetch([loss, output, target], callback, batch_idx)
|
jt.fetch(batch_idx, loss, output, target, callback)
|
||||||
|
|
||||||
log_conv = find_log_with_re(logs,
|
log_conv = find_log_with_re(logs,
|
||||||
"Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
|
"Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
|
||||||
log_matmul = find_log_with_re(logs,
|
log_matmul = find_log_with_re(logs,
|
||||||
"Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
|
"Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
|
||||||
if batch_idx:
|
if batch_idx > 2:
|
||||||
assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul))
|
assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul))
|
||||||
|
|
||||||
mem_used = jt.flags.stat_allocator_total_alloc_byte \
|
mem_used = jt.flags.stat_allocator_total_alloc_byte \
|
||||||
|
@ -114,15 +114,13 @@ class TestResnet(unittest.TestCase):
|
||||||
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
|
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
|
||||||
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
||||||
|
|
||||||
# print(jt.core.number_of_lived_vars(), mem_used)
|
if jt.in_mpi:
|
||||||
jt.display_memory_info()
|
assert jt.core.number_of_lived_vars() < 7500, jt.core.number_of_lived_vars()
|
||||||
# if jt.in_mpi:
|
else:
|
||||||
# assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
|
assert jt.core.number_of_lived_vars() < 6500, jt.core.number_of_lived_vars()
|
||||||
# else:
|
|
||||||
# assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
|
|
||||||
|
|
||||||
jt.sync_all(True)
|
jt.sync_all(True)
|
||||||
assert np.mean(loss_list[-50:])<0.3
|
assert np.mean(loss_list[-50:])<0.5
|
||||||
assert np.mean(acc_list[-50:])>0.8
|
assert np.mean(acc_list[-50:])>0.8
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -77,7 +77,7 @@ class TestVGGClass(unittest.TestCase):
|
||||||
acc_list.append(acc)
|
acc_list.append(acc)
|
||||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}'
|
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}'
|
||||||
.format(0, batch_idx, 100,1. * batch_idx, loss[0], acc))
|
.format(0, batch_idx, 100,1. * batch_idx, loss[0], acc))
|
||||||
jt.fetch([loss, output, target], callback, batch_idx)
|
jt.fetch(batch_idx, loss, output, target, callback)
|
||||||
|
|
||||||
log_conv = find_log_with_re(logs,
|
log_conv = find_log_with_re(logs,
|
||||||
"Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
|
"Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
|
||||||
|
|
|
@ -25,8 +25,8 @@ def docker_task(name, build_cmd):
|
||||||
run_cmd(build_cmd)
|
run_cmd(build_cmd)
|
||||||
run_cmd(f"sudo docker push {name}")
|
run_cmd(f"sudo docker push {name}")
|
||||||
bname = os.path.basename(name)
|
bname = os.path.basename(name)
|
||||||
run_cmd(f"docker save {name}:latest -o /tmp/{bname}.tgz && chmod 666 /tmp/{bname}.tgz")
|
run_cmd(f"sudo docker save {name}:latest -o /tmp/{bname}.tgz && sudo chmod 666 /tmp/{bname}.tgz")
|
||||||
upload_file(f" /tmp/{bname}.tgz")
|
upload_file(f"/tmp/{bname}.tgz")
|
||||||
|
|
||||||
docker_task(
|
docker_task(
|
||||||
"jittor/jittor",
|
"jittor/jittor",
|
||||||
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Copyright 2013 Benedikt Morbach <moben@exherbo.org>
|
||||||
|
# Distributed under the terms of the GNU General Public License v2
|
||||||
|
|
||||||
|
# runs multiple MPI processes as a grid in a new tmux window and multiplexes keyboard input to all of them
|
||||||
|
|
||||||
|
additional_vars=( LD_LIBRARY_PATH LD_PRELOAD )
|
||||||
|
export "${additional_vars[@]}"
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
echo 'tmpi: Run multiple MPI processes as a grid in a new tmux window and multiplex keyboard input to all of them.'
|
||||||
|
echo ''
|
||||||
|
echo 'Usage:'
|
||||||
|
echo ' tmpi [number] [command]'
|
||||||
|
echo ''
|
||||||
|
echo 'You need to pass at least two arguments.'
|
||||||
|
echo 'The first argument is the number of processes to use, every argument after that is the commandline to run.'
|
||||||
|
echo 'If you call this script from outside tmux and your command contains important whitespace then you need to appy two levels of quoting to preserve it.'
|
||||||
|
echo ''
|
||||||
|
echo 'LD_LIBRARY_PATH and LD_PRELOAD are passed through, so you can run it like this:'
|
||||||
|
echo 'LD_LIBRARY_PATH="${PWD}/.libs:${LD_LIBRARY_PATH}" tmpi 16 gdb -q bin/.libs/example'
|
||||||
|
echo ''
|
||||||
|
echo 'The new window is set to remain on exit and has to be closed manually. ("C-b + k" by default)'
|
||||||
|
}
|
||||||
|
|
||||||
|
check_tools() {
|
||||||
|
tools=( tmux mpirun )
|
||||||
|
|
||||||
|
for tool in "${tools[@]}"; do
|
||||||
|
if ! which ${tool}; then
|
||||||
|
echo "You need to install ${tool} to run this script."
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
if [[ ${#} -lt 2 ]]; then
|
||||||
|
usage
|
||||||
|
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -z ${TMUX} ]]; then
|
||||||
|
# it seems we aren't in a tmux session.
|
||||||
|
# start a new one so that our window doesn't end up in some other session and we have to search it.
|
||||||
|
# actually start a new server with '-L' to ensure that our environment carries over.
|
||||||
|
socket=$(mktemp --dry-run tmpi.XXXX)
|
||||||
|
exec tmux -L ${socket} new-session "${0} ${*}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ ${1} == runmpi ]] ; then
|
||||||
|
# we are being started as one of many processes by mpirun.
|
||||||
|
shift
|
||||||
|
|
||||||
|
# start the processes in the order of their rank.
|
||||||
|
# this avoids races, as we have to push the variables in tmux' environment.
|
||||||
|
# it has the nice side-effect that the panes are also ordered by rank.
|
||||||
|
while [[ $(cat /tmp/tmpi.lock) -ne ${OMPI_COMM_WORLD_RANK} ]] ; do
|
||||||
|
sleep 0.02
|
||||||
|
done
|
||||||
|
|
||||||
|
# get all the variables that mpirun starts us with so that we can pass them through.
|
||||||
|
mpi_vars=( $( env | grep -e MPI -e OPAL -e PMIX -e PYTHON -e debug | cut -d '=' -f1 ) )
|
||||||
|
mpi_vars+=( "${additional_vars[@]}" )
|
||||||
|
|
||||||
|
# add the variables to tmux' session environment.
|
||||||
|
# we can't just export them because the process will be started as a child of tmux, not us.
|
||||||
|
for var in "${mpi_vars[@]}"; do
|
||||||
|
tmux set-environment -t ${session} "${var}" "${!var}"
|
||||||
|
done
|
||||||
|
|
||||||
|
x=( $(tmux split-window -P -F '#{pane_pid} #{pane_id}' -t ${window} "${*}") )
|
||||||
|
pid=${x[0]}
|
||||||
|
pane=${x[1]}
|
||||||
|
|
||||||
|
for var in "${mpi_vars[@]}"; do
|
||||||
|
tmux set-environment -t ${session} -u "${var}"
|
||||||
|
done
|
||||||
|
|
||||||
|
# kill the dummy pane that opened the new window
|
||||||
|
[[ ${OMPI_COMM_WORLD_RANK} -eq 0 ]] && tmux kill-pane -t ${dummy} &> /dev/null
|
||||||
|
|
||||||
|
# set the window to tiled mode.
|
||||||
|
# have to do this after every new pane is spawned because otherwise the splits get
|
||||||
|
# smaller and smaller until tmux refuses to open new panes, despite plenty of space being left.
|
||||||
|
tmux select-layout -t ${pane} tiled &> /dev/null
|
||||||
|
|
||||||
|
# let the next process start
|
||||||
|
echo $((${OMPI_COMM_WORLD_RANK}+1)) > /tmp/tmpi.lock
|
||||||
|
|
||||||
|
# don't exit here as mpirun needs to be kept alive and it would also exit.
|
||||||
|
while [[ -d /proc/${pid} ]]; do
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
else
|
||||||
|
# we are the parent and set everything up before we start ourselves a bunch of times via mpirun.
|
||||||
|
processes=${1}
|
||||||
|
self=${0}
|
||||||
|
shift
|
||||||
|
|
||||||
|
# create an empty new dummy window which we sill later split up for the mpi processes.
|
||||||
|
x=( $(tmux new-window ${session} -P -F '#{pane_id} #{window_id} #{session_id}') )
|
||||||
|
export dummy=${x[0]}
|
||||||
|
export window=${x[1]}
|
||||||
|
export session=${x[2]}
|
||||||
|
|
||||||
|
# syncronize input to all panes.
|
||||||
|
tmux set-window-option -t ${window} synchronize-panes on &> /dev/null
|
||||||
|
tmux set-window-option -t ${window} remain-on-exit on &> /dev/null
|
||||||
|
|
||||||
|
# always start with rank 0.
|
||||||
|
echo 0 > /tmp/tmpi.lock
|
||||||
|
|
||||||
|
# re-execute ourself to spawn of the processes.
|
||||||
|
echo mpirun -np ${processes} ${self} runmpi "${@}"
|
||||||
|
mpirun -np ${processes} ${self} runmpi "${@}"
|
||||||
|
fi
|
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name='jittor',
|
name='jittor',
|
||||||
version='1.1.4.9',
|
version='1.1.5.4',
|
||||||
# scripts=[],
|
# scripts=[],
|
||||||
author="Jittor Group",
|
author="Jittor Group",
|
||||||
author_email="ran.donglang@gmail.com",
|
author_email="ran.donglang@gmail.com",
|
||||||
|
|
|
@ -9,7 +9,6 @@
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <helper_cuda.h>
|
#include <helper_cuda.h>
|
||||||
#include "mem/allocator/cuda_dual_allocator.h"
|
#include "mem/allocator/cuda_dual_allocator.h"
|
||||||
#include "fetcher.h"
|
|
||||||
#include "event_queue.h"
|
#include "event_queue.h"
|
||||||
#endif
|
#endif
|
||||||
#include "misc/cuda_flags.h"
|
#include "misc/cuda_flags.h"
|
||||||
|
@ -26,6 +25,9 @@ namespace jittor {
|
||||||
|
|
||||||
Executor exe;
|
Executor exe;
|
||||||
|
|
||||||
|
// from fetch_op.cc
|
||||||
|
extern list<VarPtr> fetcher_to_free;
|
||||||
|
|
||||||
void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
auto allocator = get_allocator();
|
auto allocator = get_allocator();
|
||||||
this->allocator = allocator;
|
this->allocator = allocator;
|
||||||
|
@ -33,22 +35,43 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
int op_num = 0;
|
int op_num = 0;
|
||||||
vector<Node*> bfs_q;
|
vector<Node*> bfs_q;
|
||||||
bfs_q.reserve(vars.size());
|
bfs_q.reserve(vars.size());
|
||||||
auto nodes = (vector<Node*>*)&vars;
|
|
||||||
int start_var_num = 0;
|
int start_var_num = 0;
|
||||||
for (Var* v : vars)
|
{
|
||||||
if (!v->is_finished())
|
// get all nodes need to be executed
|
||||||
start_var_num++;
|
auto t = ++Node::tflag_count;
|
||||||
bfs_backward(*nodes, bfs_q, [&](Node *node) -> bool {
|
for (Var* v : vars)
|
||||||
node->custom_data = 0;
|
if (!v->is_finished() && v->tflag != t) {
|
||||||
if (node->is_finished())
|
v->tflag = t;
|
||||||
return false;
|
start_var_num++;
|
||||||
op_num += !node->is_var();
|
bfs_q.push_back(v);
|
||||||
return true;
|
}
|
||||||
});
|
for (int i=0; i<bfs_q.size(); i++) {
|
||||||
|
auto node = bfs_q[i];
|
||||||
|
op_num += !node->is_var();
|
||||||
|
for (auto i : node->_inputs)
|
||||||
|
if (i.node->tflag != t && !i.node->is_finished()) {
|
||||||
|
i.node->tflag = t;
|
||||||
|
bfs_q.push_back(i.node);
|
||||||
|
}
|
||||||
|
// this var has been fetched
|
||||||
|
if (node->flags.get(NodeFlags::_fetch)) {
|
||||||
|
for (auto& n : node->_outputs) {
|
||||||
|
// if not in queue and is fetch op
|
||||||
|
if (n.node->tflag != t &&
|
||||||
|
!n.node->is_finished() &&
|
||||||
|
n.node->flags.get(NodeFlags::_fetch)) {
|
||||||
|
n.node->tflag = t;
|
||||||
|
bfs_q.push_back(n.node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
auto tt = Node::tflag_count;
|
auto tt = Node::tflag_count;
|
||||||
vector<Op*> ops;
|
vector<Op*> ops;
|
||||||
vector<Var*> all_vars;
|
vector<Var*> all_vars;
|
||||||
ops.reserve(op_num);
|
ops.reserve(op_num);
|
||||||
|
all_vars.reserve(bfs_q.size() - op_num);
|
||||||
for (Node* node : bfs_q)
|
for (Node* node : bfs_q)
|
||||||
if (!node->is_var()) {
|
if (!node->is_var()) {
|
||||||
node->custom_data = ops.size();
|
node->custom_data = ops.size();
|
||||||
|
@ -391,7 +414,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
outputs_bk.push_back(var);
|
outputs_bk.push_back(var);
|
||||||
op->finish_pending_liveness();
|
op->finish_pending_liveness();
|
||||||
for (Var* var : outputs_bk)
|
for (Var* var : outputs_bk)
|
||||||
// var->finish_pending_liveness();
|
|
||||||
var->finish_pending_liveness();
|
var->finish_pending_liveness();
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
// log memory info
|
// log memory info
|
||||||
|
@ -410,6 +432,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
}
|
}
|
||||||
LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
|
LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
|
||||||
for (Var* v : vars) ASSERT(v->mem_ptr);
|
for (Var* v : vars) ASSERT(v->mem_ptr);
|
||||||
|
// clean fetcher free buffer
|
||||||
|
fetcher_to_free.clear();
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
if (device_sync && use_cuda) {
|
if (device_sync && use_cuda) {
|
||||||
last_is_cuda = false;
|
last_is_cuda = false;
|
||||||
|
|
47
src/grad.cc
47
src/grad.cc
|
@ -27,7 +27,7 @@ VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
|
||||||
auto dx = op->grad(out, dout, x, x_index);
|
auto dx = op->grad(out, dout, x, x_index);
|
||||||
if (x->loop_options)
|
if (x->loop_options)
|
||||||
dx->loop_options = x->loop_options;
|
dx->loop_options = x->loop_options;
|
||||||
return move(dx);
|
return dx;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static void assign_attrs(Var* a, Var* b) {
|
inline static void assign_attrs(Var* a, Var* b) {
|
||||||
|
@ -92,29 +92,30 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
||||||
Op* op = it.op;
|
Op* op = it.op;
|
||||||
auto index = it.index;
|
auto index = it.index;
|
||||||
if (op->tflag != nt) continue;
|
if (op->tflag != nt) continue;
|
||||||
// TODO: support two outputs backprop.
|
for (Var* out : op->outputs()) {
|
||||||
Var* out = op->outputs().back();
|
if (out->tflag != nt) continue;
|
||||||
Var* dout = grads[out->custom_data];
|
Var* dout = grads[out->custom_data];
|
||||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||||
registe_node_trace_grad(dvar.ptr, op, index);
|
registe_node_trace_grad(dvar.ptr, op, index);
|
||||||
if (dvar)
|
if (dvar)
|
||||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||||
<< "dvar" << dvar << "var" << var;
|
<< "dvar" << dvar << "var" << var;
|
||||||
if (!grad)
|
if (!grad)
|
||||||
grad = move(dvar);
|
grad = move(dvar);
|
||||||
else if (dvar) {
|
else if (dvar) {
|
||||||
grad = make_binary(grad, dvar, ns_add);
|
grad = make_binary(grad, dvar, ns_add);
|
||||||
#ifdef PREVENT_LARGE_FUSED_OP
|
#ifdef PREVENT_LARGE_FUSED_OP
|
||||||
gsum ++;
|
gsum ++;
|
||||||
if (gsum>=PREVENT_LARGE_FUSED_OP) {
|
if (gsum>=PREVENT_LARGE_FUSED_OP) {
|
||||||
// TODO: this is a dirty fix for
|
// TODO: this is a dirty fix for
|
||||||
// stopping fuse lots of op together,
|
// stopping fuse lots of op together,
|
||||||
// try to find a better solution
|
// try to find a better solution
|
||||||
grad->flags.set(NodeFlags::_stop_fuse);
|
grad->flags.set(NodeFlags::_stop_fuse);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
assign_attrs(grad.ptr, var);
|
||||||
|
registe_node_trace_grad(grad.ptr, var, index);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
assign_attrs(grad.ptr, var);
|
|
||||||
registe_node_trace_grad(grad.ptr, var, index);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
10
src/init.cc
10
src/init.cc
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
#include "init.h"
|
#include "init.h"
|
||||||
#include "ops/op_register.h"
|
#include "ops/op_register.h"
|
||||||
|
#include "var.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
@ -21,6 +22,15 @@ unique_ptr<std::default_random_engine> eng;
|
||||||
vector<set_seed_callback> callbacks;
|
vector<set_seed_callback> callbacks;
|
||||||
int current_seed;
|
int current_seed;
|
||||||
|
|
||||||
|
// fron fetch_op.cc
|
||||||
|
extern list<VarPtr> fetcher;
|
||||||
|
extern list<VarPtr> fetcher_to_free;
|
||||||
|
|
||||||
|
void cleanup() {
|
||||||
|
fetcher_to_free.clear();
|
||||||
|
fetcher.clear();
|
||||||
|
}
|
||||||
|
|
||||||
static void init_cuda_devices() {
|
static void init_cuda_devices() {
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
int count=0;
|
int count=0;
|
||||||
|
|
|
@ -20,4 +20,8 @@ void add_set_seed_callback(set_seed_callback callback);
|
||||||
extern "C"
|
extern "C"
|
||||||
std::default_random_engine* get_random_engine();
|
std::default_random_engine* get_random_engine();
|
||||||
|
|
||||||
|
// things need to be clean before python exit
|
||||||
|
// @pyjt(cleanup)
|
||||||
|
void cleanup();
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
||||||
|
|
|
@ -95,7 +95,7 @@ struct DelayFree final : Allocator {
|
||||||
void free(void* mem_ptr, size_t size, const size_t& allocation) override {
|
void free(void* mem_ptr, size_t size, const size_t& allocation) override {
|
||||||
using namespace cuda_dual_local;
|
using namespace cuda_dual_local;
|
||||||
allocations.emplace_back(mem_ptr, allocation, size, &cuda_dual_allocator);
|
allocations.emplace_back(mem_ptr, allocation, size, &cuda_dual_allocator);
|
||||||
checkCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
|
peekCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) {
|
void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) {
|
||||||
|
|
|
@ -9,9 +9,6 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
#define FOR_ALL_TYPES(m) \
|
#define FOR_ALL_TYPES(m) \
|
||||||
m(float) \
|
|
||||||
m(double) \
|
|
||||||
m(int) \
|
|
||||||
m(bool) \
|
m(bool) \
|
||||||
m(int8) \
|
m(int8) \
|
||||||
m(int16) \
|
m(int16) \
|
||||||
|
@ -151,6 +148,10 @@ static void init_ns() {
|
||||||
NanoString::__string_to_ns["sum"] = ns_add;
|
NanoString::__string_to_ns["sum"] = ns_add;
|
||||||
NanoString::__string_to_ns["min"] = ns_minimum;
|
NanoString::__string_to_ns["min"] = ns_minimum;
|
||||||
NanoString::__string_to_ns["max"] = ns_maximum;
|
NanoString::__string_to_ns["max"] = ns_maximum;
|
||||||
|
NanoString::__string_to_ns["float"] = ns_float32;
|
||||||
|
NanoString::__string_to_ns["double"] = ns_float64;
|
||||||
|
NanoString::__string_to_ns["int"] = ns_int32;
|
||||||
|
NanoString::__string_to_ns["uint"] = ns_uint32;
|
||||||
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
||||||
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,9 +12,6 @@ namespace jittor {
|
||||||
#define FOR_ALL_NS(m) \
|
#define FOR_ALL_NS(m) \
|
||||||
\
|
\
|
||||||
m(void) \
|
m(void) \
|
||||||
m(float) \
|
|
||||||
m(double) \
|
|
||||||
m(int) \
|
|
||||||
m(bool) \
|
m(bool) \
|
||||||
m(int8) \
|
m(int8) \
|
||||||
m(int16) \
|
m(int16) \
|
||||||
|
|
|
@ -24,7 +24,9 @@ struct NodeFlags {
|
||||||
_finished=1,
|
_finished=1,
|
||||||
// bit2: stop grad
|
// bit2: stop grad
|
||||||
_stop_grad=2,
|
_stop_grad=2,
|
||||||
_n=3,
|
// bit3: is fetch
|
||||||
|
_fetch=3,
|
||||||
|
_n=4,
|
||||||
|
|
||||||
// var related flags
|
// var related flags
|
||||||
_force_fuse=_n+0,
|
_force_fuse=_n+0,
|
||||||
|
|
|
@ -32,9 +32,9 @@ Init() {
|
||||||
}
|
}
|
||||||
~Init() {
|
~Init() {
|
||||||
if (!get_device_count()) return;
|
if (!get_device_count()) return;
|
||||||
checkCudaErrors(cudaDeviceSynchronize());
|
peekCudaErrors(cudaDeviceSynchronize());
|
||||||
checkCudaErrors(cudaStreamDestroy(stream));
|
peekCudaErrors(cudaStreamDestroy(stream));
|
||||||
checkCudaErrors(cudaEventDestroy(event));
|
peekCudaErrors(cudaEventDestroy(event));
|
||||||
}
|
}
|
||||||
} init;
|
} init;
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ struct ArrayOp : Op {
|
||||||
Var* output;
|
Var* output;
|
||||||
Allocation allocation;
|
Allocation allocation;
|
||||||
// @pybind(None)
|
// @pybind(None)
|
||||||
ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float);
|
ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float32);
|
||||||
|
|
||||||
ArrayOp(ArrayArgs&& args);
|
ArrayOp(ArrayArgs&& args);
|
||||||
template<class T>
|
template<class T>
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
// Copyright (c) 2020 Jittor.
|
||||||
|
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||||
|
// All Rights Reserved.
|
||||||
// This file is subject to the terms and conditions defined in
|
// This file is subject to the terms and conditions defined in
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
|
@ -12,8 +14,9 @@
|
||||||
#include "mem/allocator/cuda_dual_allocator.h"
|
#include "mem/allocator/cuda_dual_allocator.h"
|
||||||
#include "event_queue.h"
|
#include "event_queue.h"
|
||||||
#endif
|
#endif
|
||||||
#include "fetcher.h"
|
#include "ops/fetch_op.h"
|
||||||
#include "mem/allocator.h"
|
#include "mem/allocator.h"
|
||||||
|
#include "executor.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
@ -49,31 +52,68 @@ Init() {
|
||||||
// do not call deleter on exit
|
// do not call deleter on exit
|
||||||
for (auto& f : fetch_tasks)
|
for (auto& f : fetch_tasks)
|
||||||
f.func.deleter = nullptr;
|
f.func.deleter = nullptr;
|
||||||
checkCudaErrors(cudaDeviceSynchronize());
|
peekCudaErrors(cudaDeviceSynchronize());
|
||||||
checkCudaErrors(cudaStreamDestroy(stream));
|
peekCudaErrors(cudaStreamDestroy(stream));
|
||||||
checkCudaErrors(cudaEventDestroy(event));
|
peekCudaErrors(cudaEventDestroy(event));
|
||||||
}
|
}
|
||||||
};
|
} ;
|
||||||
|
|
||||||
}
|
}
|
||||||
using namespace fetcher_local;
|
using namespace fetcher_local;
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
|
list<VarPtr> fetcher;
|
||||||
|
// this list will be free at each execution
|
||||||
|
list<VarPtr> fetcher_to_free;
|
||||||
|
|
||||||
|
FetchOp::FetchOp(vector<Var*>&& inputs, FetchFunc&& func)
|
||||||
|
: fetch_vars(inputs), func(move(func)) {
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
static Init init;
|
// stream needs to be created after nccl plugin
|
||||||
|
static Init init_fetch;
|
||||||
#endif
|
#endif
|
||||||
sync(vh);
|
VarPtr vp(0, ns_int32);
|
||||||
vector<Allocation> allocations(vh.size());
|
outputs_holder.emplace_back(vp);
|
||||||
vector<ArrayArgs> arrays(vh.size());
|
fetcher.emplace_front(move(vp));
|
||||||
|
fetcher_iter = fetcher.begin();
|
||||||
|
bool all_finished = true;
|
||||||
|
for (auto v : fetch_vars)
|
||||||
|
if (!v->is_finished()) {
|
||||||
|
all_finished = false;
|
||||||
|
v->flags.set(NodeFlags::_stop_fuse);
|
||||||
|
v->flags.set(NodeFlags::_fetch);
|
||||||
|
}
|
||||||
|
flags.set(NodeFlags::_cpu);
|
||||||
|
flags.set(NodeFlags::_cuda);
|
||||||
|
flags.set(NodeFlags::_fetch);
|
||||||
|
flags.set(NodeFlags::_stop_grad);
|
||||||
|
fetcher_iter->ptr->flags.set(NodeFlags::_fetch);
|
||||||
|
// fetcher_to_free.clear();
|
||||||
|
if (all_finished) {
|
||||||
|
// if all finished, run immediately
|
||||||
|
run();
|
||||||
|
}
|
||||||
|
// if too many fetchers are bufferd, force flush
|
||||||
|
while (fetcher.size() > 20) {
|
||||||
|
LOGvvvv << "too many fetchers(">>fetcher.size() >>
|
||||||
|
") are bufferd, force flush";
|
||||||
|
exe.run_sync({fetcher.back().ptr}, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FetchOp::run() {
|
||||||
|
vector<Allocation> allocations(fetch_vars.size());
|
||||||
|
vector<ArrayArgs> arrays(fetch_vars.size());
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
bool has_cuda_memcpy = false;
|
bool has_cuda_memcpy = false;
|
||||||
event_queue.flush();
|
event_queue.flush();
|
||||||
#endif
|
#endif
|
||||||
for (int i=0; i<vh.size(); i++) {
|
LOGvvvv << "fetch" << fetch_vars.size() << "vars" << fetch_vars;
|
||||||
auto v = vh[i]->var;
|
int i = 0;
|
||||||
|
for (auto v : fetch_vars) {
|
||||||
auto& allocation = allocations[i];
|
auto& allocation = allocations[i];
|
||||||
|
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
if (v->allocator->is_cuda()) {
|
if (v->allocator->is_cuda()) {
|
||||||
checkCudaErrors(cudaEventRecord(event, 0));
|
checkCudaErrors(cudaEventRecord(event, 0));
|
||||||
|
@ -98,6 +138,7 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
|
||||||
arrays[i].ptr = allocation.ptr;
|
arrays[i].ptr = allocation.ptr;
|
||||||
arrays[i].shape = v->shape;
|
arrays[i].shape = v->shape;
|
||||||
arrays[i].dtype = v->dtype();
|
arrays[i].dtype = v->dtype();
|
||||||
|
i++;
|
||||||
}
|
}
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
if (has_cuda_memcpy) {
|
if (has_cuda_memcpy) {
|
||||||
|
@ -109,6 +150,8 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
|
||||||
FetchResult fr{move(func), move(allocations), move(arrays)};
|
FetchResult fr{move(func), move(allocations), move(arrays)};
|
||||||
fr.call();
|
fr.call();
|
||||||
}
|
}
|
||||||
|
fetcher_to_free.emplace_front(move(*fetcher_iter));
|
||||||
|
fetcher.erase(fetcher_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -5,8 +5,9 @@
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "common.h"
|
#include "op.h"
|
||||||
#include "var_holder.h"
|
#include "var.h"
|
||||||
|
#include "mem/allocator.h"
|
||||||
#include "ops/array_op.h"
|
#include "ops/array_op.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
@ -42,7 +43,15 @@ struct FetchResult {
|
||||||
inline void call() { func.callback(this); }
|
inline void call() { func.callback(this); }
|
||||||
};
|
};
|
||||||
|
|
||||||
// @pyjt(fetch)
|
struct FetchOp final : Op {
|
||||||
void fetch(const vector<VarHolder*>& vh, FetchFunc&& func);
|
vector<Var*> fetch_vars;
|
||||||
|
FetchFunc func;
|
||||||
|
list<VarPtr>::iterator fetcher_iter;
|
||||||
|
|
||||||
} // jittor
|
FetchOp(vector<Var*>&& inputs, FetchFunc&& func);
|
||||||
|
|
||||||
|
const char* name() const override { return "fetch"; }
|
||||||
|
void run() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -16,7 +16,7 @@ static auto make_broadcast_to = get_op_info("broadcast_to")
|
||||||
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
|
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
|
||||||
|
|
||||||
VarPtr make_number(float number, Var* x) {
|
VarPtr make_number(float number, Var* x) {
|
||||||
VarPtr nums = make_array(&number, 1, ns_float);
|
VarPtr nums = make_array(&number, 1, ns_float32);
|
||||||
nums = make_broadcast_to(nums, x, {});
|
nums = make_broadcast_to(nums, x, {});
|
||||||
return make_unary(nums, x->dtype());
|
return make_unary(nums, x->dtype());
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ namespace jittor {
|
||||||
|
|
||||||
struct RandomOp : Op {
|
struct RandomOp : Op {
|
||||||
Var* output;
|
Var* output;
|
||||||
RandomOp(NanoVector shape, NanoString dtype=ns_float);
|
RandomOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||||
|
|
||||||
const char* name() const override { return "random"; }
|
const char* name() const override { return "random"; }
|
||||||
DECLARE_jit_run;
|
DECLARE_jit_run;
|
||||||
|
|
|
@ -22,9 +22,6 @@ static auto make_number = get_op_info("number")
|
||||||
.get_constructor<VarPtr, float, Var*>();
|
.get_constructor<VarPtr, float, Var*>();
|
||||||
|
|
||||||
static unordered_set<string> unary_ops = {
|
static unordered_set<string> unary_ops = {
|
||||||
"float",
|
|
||||||
"double",
|
|
||||||
"int",
|
|
||||||
"bool",
|
"bool",
|
||||||
"int8",
|
"int8",
|
||||||
"int16",
|
"int16",
|
||||||
|
|
|
@ -229,7 +229,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
||||||
if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return;
|
if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return;
|
||||||
|
|
||||||
// only support float32 currently
|
// only support float32 currently
|
||||||
if (bop->z->dtype() != ns_float && bop->z->dtype() != ns_float32)
|
if (bop->z->dtype() != ns_float32)
|
||||||
continue;
|
continue;
|
||||||
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
|
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
|
||||||
int ok = 0;
|
int ok = 0;
|
||||||
|
|
|
@ -15,6 +15,7 @@ PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp co
|
||||||
PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
|
PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
|
||||||
unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
||||||
int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
||||||
|
PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
||||||
|
|
||||||
tmp_data_t tmp_data;
|
tmp_data_t tmp_data;
|
||||||
|
|
||||||
|
@ -30,6 +31,7 @@ void numpy_init() {
|
||||||
fill(PyArray_New, 93);
|
fill(PyArray_New, 93);
|
||||||
fill(PyArray_GetNDArrayCFeatureVersion, 211);
|
fill(PyArray_GetNDArrayCFeatureVersion, 211);
|
||||||
fill(PyArray_SetBaseObject, 282);
|
fill(PyArray_SetBaseObject, 282);
|
||||||
|
fill(PyArray_NewCopy, 85);
|
||||||
|
|
||||||
ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7);
|
ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7);
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,12 +76,12 @@ inline int get_typenum(NanoString ns) {
|
||||||
if (ns == ns_uint8) return 2;
|
if (ns == ns_uint8) return 2;
|
||||||
if (ns == ns_int16) return 3;
|
if (ns == ns_int16) return 3;
|
||||||
if (ns == ns_uint16) return 4;
|
if (ns == ns_uint16) return 4;
|
||||||
if (ns == ns_int32 || ns == ns_int) return 5;
|
if (ns == ns_int32) return 5;
|
||||||
if (ns == ns_uint32) return 6;
|
if (ns == ns_uint32) return 6;
|
||||||
if (ns == ns_int64) return 7;
|
if (ns == ns_int64) return 7;
|
||||||
if (ns == ns_uint64) return 8;
|
if (ns == ns_uint64) return 8;
|
||||||
if (ns == ns_float32 || ns == ns_float) return 11;
|
if (ns == ns_float32) return 11;
|
||||||
if (ns == ns_float64 || ns == ns_double) return 12;
|
if (ns == ns_float64) return 12;
|
||||||
LOGf << ns;
|
LOGf << ns;
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
@ -97,6 +97,8 @@ extern PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_
|
||||||
extern PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
|
extern PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
|
||||||
extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
||||||
extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
||||||
|
extern PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
||||||
|
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)
|
||||||
|
|
||||||
#define NPY_ARRAY_ALIGNED 0x0100
|
#define NPY_ARRAY_ALIGNED 0x0100
|
||||||
#define NPY_ARRAY_WRITEABLE 0x0400
|
#define NPY_ARRAY_WRITEABLE 0x0400
|
||||||
|
|
|
@ -293,21 +293,23 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
|
||||||
auto ptr = GET_RAW_PTR(VarHolder, obj);
|
auto ptr = GET_RAW_PTR(VarHolder, obj);
|
||||||
return move(fetch_sync({ptr}).at(0));
|
return move(fetch_sync({ptr}).at(0));
|
||||||
}
|
}
|
||||||
if (Py_TYPE(obj) != PyArray_Type) {
|
// PyArray_Type
|
||||||
PyObjHolder holder(PyArray_FROM_O(obj));
|
auto arr = (PyArray_Proxy*)obj;
|
||||||
|
if (Py_TYPE(obj) != PyArray_Type || !is_c_style(arr)) {
|
||||||
|
PyObjHolder holder(
|
||||||
|
Py_TYPE(obj) != PyArray_Type ?
|
||||||
|
PyArray_FROM_O(obj) :
|
||||||
|
PyArray_Copy(obj));
|
||||||
auto arr = (PyArray_Proxy*)holder.obj;
|
auto arr = (PyArray_Proxy*)holder.obj;
|
||||||
int64 size = PyArray_Size(arr);
|
int64 size = PyArray_Size(arr);
|
||||||
T args;
|
T args;
|
||||||
args.ptr = arr->data;
|
|
||||||
args.shape = vector<int64>(arr->dimensions, arr->dimensions+arr->nd);
|
args.shape = vector<int64>(arr->dimensions, arr->dimensions+arr->nd);
|
||||||
args.dtype = get_type_str(arr);
|
args.dtype = get_type_str(arr);
|
||||||
args.buffer.reset(new char[size]);
|
args.buffer.reset(new char[size]);
|
||||||
|
args.ptr = (void*)args.buffer.get();
|
||||||
memcpy((void*)args.buffer.get(), (void*)arr->data, size);
|
memcpy((void*)args.buffer.get(), (void*)arr->data, size);
|
||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
// PyArray_Type
|
|
||||||
auto arr = (PyArray_Proxy*)obj;
|
|
||||||
CHECK(is_c_style(arr));
|
|
||||||
T args;
|
T args;
|
||||||
args.ptr = arr->data;
|
args.ptr = arr->data;
|
||||||
if (arr->dimensions)
|
if (arr->dimensions)
|
||||||
|
|
|
@ -97,6 +97,9 @@ ArrayArgs VarHolder::fetch_sync() {
|
||||||
return {var->mem_ptr, var->shape, var->dtype()};
|
return {var->mem_ptr, var->shape, var->dtype()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// from fetch_op.cc
|
||||||
|
extern list<VarPtr> fetcher;
|
||||||
|
|
||||||
void sync_all(bool device_sync) {
|
void sync_all(bool device_sync) {
|
||||||
vector<Var*> vars;
|
vector<Var*> vars;
|
||||||
vars.reserve(VarHolder::hold_vars.size());
|
vars.reserve(VarHolder::hold_vars.size());
|
||||||
|
@ -104,6 +107,8 @@ void sync_all(bool device_sync) {
|
||||||
if (!v->var->_outputs.size())
|
if (!v->var->_outputs.size())
|
||||||
vars.push_back(v->var);
|
vars.push_back(v->var);
|
||||||
}
|
}
|
||||||
|
for (auto& v :fetcher)
|
||||||
|
vars.push_back(v.ptr);
|
||||||
graph_check();
|
graph_check();
|
||||||
exe.run_sync(vars, device_sync); //need sync at last
|
exe.run_sync(vars, device_sync); //need sync at last
|
||||||
graph_check();
|
graph_check();
|
||||||
|
|
|
@ -106,7 +106,7 @@ struct VarHolder {
|
||||||
/* detach the grad */
|
/* detach the grad */
|
||||||
// @pyjt(detach)
|
// @pyjt(detach)
|
||||||
inline VarHolder* detach() {
|
inline VarHolder* detach() {
|
||||||
return new VarHolder(move(jittor::detach(var)));
|
return new VarHolder(jittor::detach(var));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue