mirror of https://github.com/Jittor/Jittor

commit ff6c59f385: Merge branch 'master' of https://github.com/Jittor/jittor into numpy_code

@@ -1,2 +1,28 @@
 Dockerfile
 **/publish.py
+my
+.git
+.refresh
+__pycache__
+.ipynb_checkpoints/
+.vscode/
+__res/
+perf.data
+perf.data.old
+*.swp
+*.ipynb
+*.pdf
+*.zip
+*.tgz
+test.py
+extern/mkl/mkldnn_lnx*/*
+data/
+build/
+venv/
+*.md
+!*.src.md
+!README.md
+!README.cn.md
+python/jittor.egg-info
+dist/
+!doc/source/*

@@ -38,12 +38,14 @@ RUN pip3 install matplotlib

 RUN apt install openmpi-bin openmpi-common libopenmpi-dev -y

 RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example

 RUN pip3 uninstall jittor -y

 COPY . .

 RUN pip3 install . --timeout 100

 RUN python3.7 -m jittor.test.test_example

 RUN rm -rf ~/.cache/jittor/default

 CMD python3.7 -m jittor.notebook --allow-root --ip=0.0.0.0

@@ -24,7 +24,7 @@ CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims)
     : x(x), offsets(offsets), op(op), keepdims(keepdims) {
     flags.set(NodeFlags::_cpu, 0);
     flags.set(NodeFlags::_cuda, 1);
-    ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32);
+    ASSERT(offsets->dtype()==ns_int32);
     y = create_output(nullptr, ns_int32);
     y_key = create_output(nullptr, x->dtype());
 }

@@ -23,7 +23,7 @@ CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending,
     : x(x), indexes(indexes), offsets(offsets), descending(descending) {
     flags.set(NodeFlags::_cpu, 0);
     flags.set(NodeFlags::_cuda, 1);
-    ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32);
+    ASSERT(offsets->dtype()==ns_int32);
     y = create_output(nullptr, dtype);
     y_key = create_output(nullptr, x->dtype());
 }

@@ -13,7 +13,7 @@ namespace jittor {

 struct CurandRandomOp : Op {
     Var* output;
-    CurandRandomOp(NanoVector shape, NanoString dtype=ns_float);
+    CurandRandomOp(NanoVector shape, NanoString dtype=ns_float32);

     const char* name() const override { return "curand_random"; }
     DECLARE_jit_run;

@@ -101,6 +101,17 @@ const char *_cudaGetErrorEnum(NppStatus error);
 #endif
 #endif

+template <typename T>
+void peek(T result, char const *const func, const char *const file,
+          int const line) {
+    if (result) {
+        // DEVICE_RESET
+        LOGe << "Peek CUDA error at" << file >> ":" >> line << " code="
+            >> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
+            << func;
+    }
+}
+
 template <typename T>
 void check(T result, char const *const func, const char *const file,
            int const line) {

@@ -116,6 +127,7 @@ void check(T result, char const *const func, const char *const file,
 // This will output the proper CUDA error strings in the event
 // that a CUDA host call returns an error
 #define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
+#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__)

 // This will output the proper error string when calling cudaGetLastError
 #define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__)

@@ -83,7 +83,7 @@ mpi_initer() {
     MPI_CHECK(MPI_Init(NULL, NULL));
     MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
     MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));

     // calculating localRank based on hostname, which is used in selecting a GPU
     uint64_t hostHashs[mpi_world_rank];
     char hostname[1024];

@@ -33,7 +33,7 @@ namespace jittor {

 struct CustomOp : Op {
     Var* output;
-    CustomOp(NanoVector shape, NanoString dtype=ns_float);
+    CustomOp(NanoVector shape, NanoString dtype=ns_float32);

     const char* name() const override { return "custom"; }
     DECLARE_jit_run;

@@ -207,18 +207,17 @@ def liveness_info():
     }

 def ones(shape, dtype="float32"):
     if not isinstance(shape, (NanoVector, Sequence)):
         shape = (shape,)
     return unary(1, dtype).broadcast(shape)

 def zeros(shape, dtype="float32"):
     if not isinstance(shape, (NanoVector, Sequence)):
         shape = (shape,)
     return unary(0, dtype).broadcast(shape)

 flags = core.flags()

 def detach(x):
     """return detached var"""
     return x.clone().stop_grad().clone()
 Var.detach = detach

 def std(x):
     matsize=1
     for i in x.shape:

@@ -305,11 +304,11 @@ Var.masked_fill = masked_fill
 def sqr(x): return x*x
 Var.sqr = sqr

-def argmax(x, dim:int, keepdims:bool=False):
+def argmax(x, dim, keepdims:bool=False):
     return x.arg_reduce("max", dim, keepdims)
 Var.argmax = argmax

-def argmin(x, dim:int, keepdims:bool=False):
+def argmin(x, dim, keepdims:bool=False):
     return x.arg_reduce("min", dim, keepdims)
 Var.argmin = argmin

@@ -322,13 +321,54 @@ def attrs(var):
     }
 Var.attrs = attrs

-def fetch(vars, func, *args, **kw):
-    core.fetch(vars, lambda *results: func(*results, *args, **kw))
-
-def fetch_var(var, func, *args, **kw):
-    core.fetch([var], lambda a: func(a, *args, **kw))
-Var.fetch = fetch_var
-del fetch_var
+def fetch(*args):
+    ''' Async fetch vars with function closure.
+
+    Example 1::
+
+        for img,label in enumerate(your_dataset):
+            pred = your_model(img)
+            loss = critic(pred, label)
+            acc = accuracy(pred, label)
+            jt.fetch(acc, loss,
+                lambda acc, loss:
+                    print(f"loss:{loss} acc:{acc}"))
+
+    Example 2::
+
+        for i,(img,label) in enumerate(your_dataset):
+            pred = your_model(img)
+            loss = critic(pred, label)
+            acc = accuracy(pred, label)
+            # variable i will be bound into the function closure
+            jt.fetch(i, acc, loss,
+                lambda i, acc, loss:
+                    print(f"#{i}, loss:{loss} acc:{acc}"))
+    '''
+    assert len(args)>=1
+    func = args[-1]
+    assert callable(func)
+    args = list(args[:-1])
+    if len(args)>0 and isinstance(args[0], Sequence) \
+        and len(args[0])>=1 and isinstance(args[0][0], Var):
+        raise TypeError("jt.Var should not be inside a list or tuple.")
+
+    var_map = []
+    variables = []
+    for i, v in enumerate(args):
+        if isinstance(v, Var):
+            variables.append(v)
+            var_map.append(i)
+            args[i] = None
+    def callback(*results):
+        for i,v in enumerate(results):
+            args[var_map[i]] = v
+        func(*args)
+    core.ops.fetch(variables, callback)
+
+Var.fetch = fetch

 def display_memory_info():
     import inspect, os

@@ -440,11 +480,11 @@ class Module:
         end = 0
         for k in key_:
             if isinstance(v, nn.Sequential):
-                if np.int(k) >= len(v.layers):
+                if ori_int(k) >= len(v.layers):
                     end = 1
                     break
                 else:
-                    v = v[np.int(k)]
+                    v = v[ori_int(k)]
             else:
                 if hasattr(v, k):
                     v = getattr(v, k)

@@ -458,12 +498,12 @@ class Module:
             else:
                 LOG.v(f'load parameter {key} success ...')
             if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
-                v.assign(array(params[key]))
+                v.update(array(params[key]))
             elif isinstance(params[key], Var):
-                v.assign(params[key])
+                v.update(params[key])
             else:
                 # assume it is a pytorch tensor
-                v.assign(array(params[key].cpu().detach().numpy()))
+                v.update(array(params[key].cpu().detach().numpy()))
         if n_failed:
             LOG.w(f"load total {len(params)} params, {n_failed} failed")

@@ -516,7 +556,7 @@ class Module:
     def mpi_param_broadcast(self, root=0):
         if not in_mpi: return
         for p in self.parameters():
-            p.assign(p.mpi_broadcast(root).detach())
+            p.update(p.mpi_broadcast(root))

 def make_module(func, exec_n_args=1):
     class MakeModule(Module):

@@ -575,12 +615,23 @@ def jittor_exit():
         pass
     else:
         core.sync_all(True)
+        core.cleanup()
 atexit.register(jittor_exit)

 Var.__str__ = lambda x: str(x.data)
 Var.__repr__ = lambda x: str(x.data)
 Var.peek = lambda x: f"{x.dtype}{x.shape}"

+
+ori_int = int
+
+int = int32
+Var.int = Var.int32
+float = float32
+Var.float = Var.float32
+double = float64
+Var.double = Var.float64
+
 from . import nn
 from .nn import matmul
 from . import contrib

@@ -345,7 +345,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
         with open(os.path.join(jittor_path, header), encoding='utf8') as f:
             src = f.read()
         # XxxXxxOp(args)
-        res = re.findall(pybind_attrs_reg + '('+name2+"\\([^\\n]*\\))", src, re.S)
+        res = re.findall(pybind_attrs_reg + '[^~]('+name2+"\\([^\\n]*\\))", src, re.S)
         assert len(res) >= 1, "Wrong op args in " + header
         # register op
         cc_name = os.path.join(jittor_path, header[:-2] + ".cc")

@@ -908,14 +908,14 @@ with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
     f.write(jit_src)
 cc_flags += f' -I{cache_path} '
 # gen pyjt
-pyjt_compiler.compile(cache_path, jittor_path)
+pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)

 # initialize order:
 # 1. registers
 # 2. generate source
 # 3. op_utils
 # 4. other
-files2 = run_cmd(f'find "{os.path.join(cache_path, "gen")}" | grep "cc$"').splitlines()
+files2 = pyjt_gen_src
 files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
 at_beginning = [
     "src/ops/op_utils.cc",

@@ -50,7 +50,7 @@ def slice_var_index(x, slices):
         slices = (slices,)
     if isinstance(slices[0], jt.Var):
         if len(slices) == 1 and slices[0].dtype == "bool":
-            return (slices[0].where(),)
+            return slice_var_index(x, tuple(slices[0].where()))
     bc = []
     ml = -1
     for idx, s in enumerate(slices):

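A quick illustration of what the fixed branch does (a hedged sketch; it assumes a jittor build with this patch, and the `mask` variable here is hypothetical): a boolean index is first rewritten into integer indices via where(), then handled by the ordinary slice path.

    import jittor as jt

    x = jt.array([1, 2, 3, 4])
    mask = jt.array([True, False, True, False])
    # the bool mask becomes a tuple of integer index vars
    idx = tuple(mask.where())
    print(x[idx].data)   # [1 3]
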
@@ -29,6 +29,16 @@ def matmul_transpose(a, b):
     b = b.broadcast(shape)
     return (a*b).sum(len(shape)-1)

+
+def bmm(a, b):
+    assert len(a.shape) >= 2 and len(b.shape) >= 2
+    assert a.shape[-1] == b.shape[-2]
+
+    shape = list(a.shape) + [b.shape[-1]]
+    a = a.broadcast(shape, [len(shape)-1])
+    b = b.broadcast(shape, [len(shape)-3])
+    return (a*b).sum(len(shape)-2)
+
 def matmul(a, b):
     assert len(a.shape) >= 2 and len(b.shape) == 2
     assert a.shape[-1] == b.shape[-2]

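The new bmm realizes a batched matrix multiply by broadcasting both operands to a common (..., n, m, k) shape and reducing over the shared dimension. A short usage sketch (assuming this patched jittor and numpy are available):

    import numpy as np
    import jittor as jt
    from jittor import nn

    a = jt.array(np.random.rand(16, 3, 4).astype("float32"))
    b = jt.array(np.random.rand(16, 4, 5).astype("float32"))
    c = nn.bmm(a, b)                        # shape (16, 3, 5)
    expected = np.matmul(a.data, b.data)    # reference result
    assert np.allclose(c.data, expected, atol=1e-5)
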
@@ -191,8 +201,10 @@ class BatchNorm(Module):

             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
+            self.running_mean.update(self.running_mean +
+                (xmean.reshape((-1,)) - self.running_mean) * self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.reshape((-1,))-self.running_var)*self.momentum)
         else:
             running_mean = self.running_mean.broadcast(x, [0,2,3])
             running_var = self.running_var.broadcast(x, [0,2,3])

@@ -225,8 +237,10 @@ class BatchNorm1d(Module):

             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
+            self.running_mean.update(self.running_mean +
+                (xmean.sum([0])-self.running_mean)*self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.sum([0])-self.running_var)*self.momentum)
         else:
             running_mean = self.running_mean.broadcast(x, [0])
             running_var = self.running_var.broadcast(x, [0])

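The running-statistics rule in both BatchNorm hunks is the usual exponential moving average; only the write now goes through update() instead of an in-place +=. A small NumPy sketch of the same arithmetic (illustrative only, with made-up shapes):

    import numpy as np

    momentum, eps = 0.1, 1e-5
    x = np.random.randn(8, 4).astype("float32")   # (N, C) batch, BatchNorm1d case
    xmean, xvar = x.mean(0), x.var(0)

    running_mean = np.zeros(4, "float32")
    running_var = np.ones(4, "float32")
    # same math as running_mean.update(running_mean + (xmean - running_mean)*momentum)
    running_mean += (xmean - running_mean) * momentum
    running_var += (xvar - running_var) * momentum
    norm_x = (x - xmean) / np.sqrt(xvar + eps)
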
@@ -53,9 +53,6 @@ class Optimizer(object):
                 params.append(p)
                 if not p.is_stop_grad():
                     params_has_grad.append(p)
-
-        # sync params, reduce computing graph size
-        jt.sync(params)

         # get gradient
         grads = jt.grad(loss, params_has_grad)

@@ -75,7 +72,8 @@ class Optimizer(object):
             pg_grads = pg["grads"]
             for i, p in enumerate(pg['params']):
                 if not p.is_stop_grad():
-                    pg_grads[i] = grads[pid]
+                    # stop grad of grad
+                    pg_grads[i] = grads[pid].stop_grad()
                     pid += 1

     def step(self, loss):

@@ -84,9 +82,7 @@ class Optimizer(object):
             lr = pg.get("lr", self.lr)
             for p, g in zip(pg["params"], pg["grads"]):
                 if p.is_stop_grad(): continue
-                p -= g * lr
-                # detach with the prev graph to reduce memory consumption
-                p.detach_inplace()
+                p.update(p - g * lr)


 class SGD(Optimizer):

@@ -108,7 +104,7 @@ class SGD(Optimizer):
         for pg in self.param_groups:
             values = pg["values"] = []
             for p in pg["params"]:
-                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

     def step(self, loss):
         self.pre_step(loss)

@@ -124,12 +120,11 @@ class SGD(Optimizer):
             for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                 if p.is_stop_grad(): continue
                 dp = p * weight_decay + g
-                v.assign(momentum * v + dp * (1 - dampening))
+                v.update(momentum * v + dp * (1 - dampening))
                 if nesterov:
-                    p -= (dp + momentum * v) * lr
+                    p.update(p - (dp + momentum * v) * lr)
                 else:
-                    p -= v * lr
-                    p.detach_inplace()
+                    p.update(p - v * lr)

 class RMSprop(Optimizer):
     """ RMSprop Optimizer.

@@ -152,7 +147,7 @@ class RMSprop(Optimizer):
         for pg in self.param_groups:
             values = pg["values"] = []
             for p in pg["params"]:
-                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

     def step(self, loss):
         self.pre_step(loss)

@@ -163,9 +158,8 @@ class RMSprop(Optimizer):
             alpha = pg.get("alpha", self.alpha)
             for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                 if p.is_stop_grad(): continue
-                v.assign(alpha * v + (1-alpha) * g * g)
-                p -= lr * g / (jt.sqrt(v) + eps)
-                p.detach_inplace()
+                v.update(alpha * v + (1-alpha) * g * g)
+                p.update(p - lr * g / (jt.sqrt(v) + eps))

 class Adam(Optimizer):
     """ Adam Optimizer.

@@ -187,8 +181,8 @@ class Adam(Optimizer):
             values = pg["values"] = []
             m = pg["m"] = []
             for p in pg["params"]:
-                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
-                m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

     def step(self, loss):
         self.pre_step(loss)

@@ -200,8 +194,7 @@ class Adam(Optimizer):
             b0, b1 = pg.get("betas", self.betas)
             for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                 if p.is_stop_grad(): continue
-                m.assign(b0 * m + (1-b0) * g)
-                v.assign(b1 * v + (1-b1) * g * g)
+                m.update(b0 * m + (1-b0) * g)
+                v.update(b1 * v + (1-b1) * g * g)
                 step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
-                p -= m * step_size / (jt.sqrt(v) + eps)
-                p.detach_inplace()
+                p.update(p - m * step_size / (jt.sqrt(v) + eps))

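All of these optimizer hunks follow one pattern: the in-place "p -= ...; p.detach_inplace()" sequence is replaced by a single update() call, which hands the assignment to the update queue. A minimal sketch of the new pattern (illustrative; `params` and `grads` stand in for pg["params"]/pg["grads"]):

    import jittor as jt

    lr = 0.1
    params = [jt.ones((2,))]
    grads = [jt.ones((2,))]
    for p, g in zip(params, grads):
        # old: p -= g * lr; p.detach_inplace()
        # new: enqueue the assignment instead of mutating in place
        p.update(p - g * lr)
    jt.sync_all()
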
@@ -258,7 +258,7 @@ def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs...):

     LOG.vvv("gen err from func_head", func_head)
     args = func_head[1:].split(")")[0].split(",")
-    error_code = f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\""
+    error_code = f" << \"Wrong inputs arguments, Please refer to examples({help_cmd}).\""
     error_code += r' << "\n\nTypes of your inputs are:\n"'
     for arg in args:
         arg = arg.strip()

@@ -849,6 +849,7 @@ def compile(cache_path, jittor_path):
     headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
               [ os.path.join(cache_path, h) for h in headers2 ]
     basenames = []
+    pyjt_names = []
     for h in headers:
         with open(h, 'r') as f:
             src = f.read()

@@ -866,6 +867,7 @@ def compile(cache_path, jittor_path):
         if not check: continue

         basenames.append(basename)
+        pyjt_names.append(fname)

     code = f"""
     #include "pyjt/numpy.h"

@@ -888,3 +890,5 @@ def compile(cache_path, jittor_path):
     LOG.vvvv(code)
     with open(fname, "w") as f:
         f.write(code)
+    pyjt_names.append(fname)
+    return pyjt_names

@@ -60,6 +60,7 @@ class TestArray(unittest.TestCase):
         for i in range(3):
             x = jt.array(im)
             b = net(x)
             b.fetch(lambda b: None)
+            b.sync()
         jt.sync(device_sync=True)

@@ -70,6 +71,7 @@ class TestArray(unittest.TestCase):
             x = jt.array(im)
             b = net(x)
             b.fetch(lambda b: results.append(b))
+            b.sync()
             # del c
         jt.sync(device_sync=True)
         t2 = time.time() - time_start

@@ -111,6 +113,12 @@ class TestArray(unittest.TestCase):
         """)
         assert (b.data==[2,8,18]).all()

+    def test_not_c_style(self):
+        a = np.array([1,2,3])
+        b = a[::-1]
+        x = jt.array(b)
+        x = x + b
+        assert (x.data == [6,4,2]).all()

@@ -16,7 +16,7 @@ def expect_error(func):

 class TestCore(unittest.TestCase):
     def test_number_of_hold_vars(self):
-        assert jt.random([1,2,3]).peek() == "float[1,2,3,]"
+        assert jt.random([1,2,3]).peek() == "float32[1,2,3,]"
         assert jt.core.number_of_hold_vars() == 0
         x = jt.random([1,2,3])
         assert jt.core.number_of_hold_vars() == 1

@@ -16,7 +16,7 @@ namespace jittor {

 struct CustomOp : Op {
     Var* output;
-    CustomOp(NanoVector shape, NanoString dtype=ns_float);
+    CustomOp(NanoVector shape, NanoString dtype=ns_float32);

     const char* name() const override { return "custom"; }
     DECLARE_jit_run;

@@ -75,7 +75,7 @@ class TestCustomOp(unittest.TestCase):
         my_op = jt.compile_custom_op("""
         struct MyOp : Op {
             Var* output;
-            MyOp(NanoVector shape, NanoString dtype=ns_float);
+            MyOp(NanoVector shape, NanoString dtype=ns_float32);

             const char* name() const override { return "my"; }
             DECLARE_jit_run;

@@ -13,7 +13,10 @@ class TestFetcher(unittest.TestCase):
         a = jt.array([1,2,3])
         a = a*2
         v = []
-        jt.fetch([a], lambda a: v.append(a))
+        jt.fetch(a, lambda a: v.append(a))
+        jt.fetch(1, 2, 3, a,
+            lambda x, y, z, a: self.assertTrue(x==1 and y==2 and z==3 and isinstance(a, np.ndarray))
+        )
         jt.sync_all(True)
         assert len(v)==1 and (v[0]==[2,4,6]).all()

@@ -38,8 +38,10 @@ class FakeMpiBatchNorm(nn.Module):

             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
+            self.running_mean.update(self.running_mean +
+                (xmean.sum([0,2,3])-self.running_mean)*self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.sum([0,2,3])-self.running_var)*self.momentum)
         else:
             running_mean = self.running_mean.broadcast(x, [0,2,3])
             running_var = self.running_var.broadcast(x, [0,2,3])

@@ -10,7 +10,7 @@ from .test_core import expect_error
 import os

 mid = 0
-if os.uname()[1] == "jittor-ce":
+if "jittor" in os.uname()[1]:
     mid = 1

 class TestNanoString(unittest.TestCase):

@@ -27,7 +27,8 @@ class TestNanoString(unittest.TestCase):
         assert t < [1.5e-7, 1.7e-7][mid], t

         assert (jt.hash("asdasd") == 4152566416)
-        assert str(jt.NanoString("float"))=="float"
+        assert str(jt.NanoString("float"))=="float32"
+        assert jt.NanoString("float")=="float32"
         # pybind11: 7
         # Tuple call: 1.3
         # fast call (with or without): 0.9

@@ -38,14 +39,14 @@ class TestNanoString(unittest.TestCase):

     def test_type(self):
         import numpy as np
-        assert str(jt.NanoString(float)) == "float"
-        assert str(jt.NanoString(np.float)) == "float"
+        assert str(jt.NanoString(float)) == "float32"
+        assert str(jt.NanoString(np.float)) == "float32"
         assert str(jt.NanoString(np.float32)) == "float32"
         assert str(jt.NanoString(np.float64)) == "float64"
         assert str(jt.NanoString(np.int8)) == "int8"
         assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64"

-        assert str(jt.NanoString(jt.float)) == "float"
+        assert str(jt.NanoString(jt.float)) == "float32"
         assert str(jt.NanoString(jt.float32)) == "float32"
         assert str(jt.NanoString(jt.float64)) == "float64"
         assert str(jt.NanoString(jt.int8)) == "int8"

@@ -99,6 +99,7 @@ class TestResizeAndCrop(unittest.TestCase):
         test_case(20, [1024, 1024], [1.2, 1.8][mid])
         test_case(20, [1024, 666], [0.8,1.0][mid])

+    @unittest.skipIf(torch is None, "no torch found")
     def test_resize(self):
         import torch.nn.functional as F
         x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")

@@ -108,11 +109,13 @@ class TestResizeAndCrop(unittest.TestCase):
                 jnn.Resize((r_size, r_size), 'bilinear', align_corners),
                 lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear', align_corners=align_corners))

+    @unittest.skipIf(torch is None, "no torch found")
     def test_upsample(self):
         arr = np.random.randn(2,3,224,224)
         check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
         check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))

+    @unittest.skipIf(torch is None, "no torch found")
     def test_pixelshuffle(self):
         arr = np.random.randn(2,4,224,224)
         check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))

@@ -44,6 +44,7 @@ class TestResnet(unittest.TestCase):
         # mnist dataset
         self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
             .set_attrs(batch_size=self.batch_size, shuffle=True)
+        self.train_loader.num_workers = 4

     # setup random seed
     def setup_seed(self, seed):

@@ -63,16 +64,16 @@ class TestResnet(unittest.TestCase):
         SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)

         for batch_idx, (data, target) in enumerate(self.train_loader):
-            output = mnist_net(data)
-            loss = nn.cross_entropy_loss(output, target)
-
             # train step
             with jt.log_capture_scope(
                 log_silent=1,
                 log_v=1, log_vprefix="op.cc=100,exe=10",
             ) as logs:
+                output = mnist_net(data)
+                loss = nn.cross_entropy_loss(output, target)
                 SGD.step(loss)
-            def callback(loss, output, target, batch_idx):
+            def callback(batch_idx, loss, output, target):
                 # print train info
                 global prev
                 pred = np.argmax(output, axis=1)

@@ -82,13 +83,13 @@ class TestResnet(unittest.TestCase):
                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                     .format(0, batch_idx, 600, 1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
                 # prev = time.time()
-            jt.fetch([loss, output, target], callback, batch_idx)
+            jt.fetch(batch_idx, loss, output, target, callback)

             log_conv = find_log_with_re(logs,
                 "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
             log_matmul = find_log_with_re(logs,
                 "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
-            if batch_idx:
+            if batch_idx > 2:
                 assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul))

             mem_used = jt.flags.stat_allocator_total_alloc_byte \

@@ -114,12 +115,12 @@ class TestResnet(unittest.TestCase):
         # Train Epoch: 0 [50/100 (50%)]	Loss: 2.055014	Acc: 0.290000

         if jt.in_mpi:
-            assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
+            assert jt.core.number_of_lived_vars() < 7500, jt.core.number_of_lived_vars()
         else:
-            assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
+            assert jt.core.number_of_lived_vars() < 6500, jt.core.number_of_lived_vars()

         jt.sync_all(True)
-        assert np.mean(loss_list[-50:])<0.3
+        assert np.mean(loss_list[-50:])<0.5
         assert np.mean(acc_list[-50:])>0.8

 if __name__ == "__main__":

@@ -25,7 +25,9 @@ class TestStopFuse(unittest.TestCase):
                 jt.sync(dbs+[a])

         for a in report[1:]:
-            assert len(a[0].split("opkey")) < 50
+            # originally 50
+            # after the update queue was added, this increased to 102
+            assert len(a[0].split("opkey")) < 110, len(a[0].split("opkey"))

     def test_stop_fuse2(self):
         with jt.profile_scope() as report:

@@ -43,7 +45,9 @@ class TestStopFuse(unittest.TestCase):
                 jt.sync(dbs+[a])

         for a in report[1:]:
-            assert len(a[0].split("opkey")) < 8
+            # originally 8
+            # after the update queue was added, this increased to 12
+            assert len(a[0].split("opkey")) < 16, len(a[0].split("opkey"))

 if __name__ == "__main__":
     unittest.main()

@@ -77,7 +77,7 @@ class TestVGGClass(unittest.TestCase):
                 acc_list.append(acc)
                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}'
                     .format(0, batch_idx, 100, 1. * batch_idx, loss[0], acc))
-            jt.fetch([loss, output, target], callback, batch_idx)
+            jt.fetch(batch_idx, loss, output, target, callback)

             log_conv = find_log_with_re(logs,
                 "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")

@@ -25,8 +25,8 @@ def docker_task(name, build_cmd):
     run_cmd(build_cmd)
     run_cmd(f"sudo docker push {name}")
    bname = os.path.basename(name)
-    run_cmd(f"docker save {name}:latest -o /tmp/{bname}.tgz && chmod 666 /tmp/{bname}.tgz")
-    upload_file(f" /tmp/{bname}.tgz")
+    run_cmd(f"sudo docker save {name}:latest -o /tmp/{bname}.tgz && sudo chmod 666 /tmp/{bname}.tgz")
+    upload_file(f"/tmp/{bname}.tgz")

 docker_task(
     "jittor/jittor",

@@ -0,0 +1,117 @@
+#!/bin/bash
+
+# Copyright 2013 Benedikt Morbach <moben@exherbo.org>
+# Distributed under the terms of the GNU General Public License v2
+
+# runs multiple MPI processes as a grid in a new tmux window and multiplexes keyboard input to all of them
+
+additional_vars=( LD_LIBRARY_PATH LD_PRELOAD )
+export "${additional_vars[@]}"
+
+usage() {
+    echo 'tmpi: Run multiple MPI processes as a grid in a new tmux window and multiplex keyboard input to all of them.'
+    echo ''
+    echo 'Usage:'
+    echo '    tmpi [number] [command]'
+    echo ''
+    echo 'You need to pass at least two arguments.'
+    echo 'The first argument is the number of processes to use, every argument after that is the commandline to run.'
+    echo 'If you call this script from outside tmux and your command contains important whitespace then you need to apply two levels of quoting to preserve it.'
+    echo ''
+    echo 'LD_LIBRARY_PATH and LD_PRELOAD are passed through, so you can run it like this:'
+    echo 'LD_LIBRARY_PATH="${PWD}/.libs:${LD_LIBRARY_PATH}" tmpi 16 gdb -q bin/.libs/example'
+    echo ''
+    echo 'The new window is set to remain on exit and has to be closed manually. ("C-b + k" by default)'
+}
+
+check_tools() {
+    tools=( tmux mpirun )
+
+    for tool in "${tools[@]}"; do
+        if ! which ${tool}; then
+            echo "You need to install ${tool} to run this script."
+        fi
+    done
+}
+
+if [[ ${#} -lt 2 ]]; then
+    usage
+
+    exit 1
+fi
+
+if [[ -z ${TMUX} ]]; then
+    # it seems we aren't in a tmux session.
+    # start a new one so that our window doesn't end up in some other session and we have to search it.
+    # actually start a new server with '-L' to ensure that our environment carries over.
+    socket=$(mktemp --dry-run tmpi.XXXX)
+    exec tmux -L ${socket} new-session "${0} ${*}"
+fi
+
+if [[ ${1} == runmpi ]] ; then
+    # we are being started as one of many processes by mpirun.
+    shift
+
+    # start the processes in the order of their rank.
+    # this avoids races, as we have to push the variables in tmux' environment.
+    # it has the nice side-effect that the panes are also ordered by rank.
+    while [[ $(cat /tmp/tmpi.lock) -ne ${OMPI_COMM_WORLD_RANK} ]] ; do
+        sleep 0.02
+    done
+
+    # get all the variables that mpirun starts us with so that we can pass them through.
+    mpi_vars=( $( env | grep -e MPI -e OPAL -e PMIX -e PYTHON -e debug | cut -d '=' -f1 ) )
+    mpi_vars+=( "${additional_vars[@]}" )
+
+    # add the variables to tmux' session environment.
+    # we can't just export them because the process will be started as a child of tmux, not us.
+    for var in "${mpi_vars[@]}"; do
+        tmux set-environment -t ${session} "${var}" "${!var}"
+    done
+
+    x=( $(tmux split-window -P -F '#{pane_pid} #{pane_id}' -t ${window} "${*}") )
+    pid=${x[0]}
+    pane=${x[1]}
+
+    for var in "${mpi_vars[@]}"; do
+        tmux set-environment -t ${session} -u "${var}"
+    done
+
+    # kill the dummy pane that opened the new window
+    [[ ${OMPI_COMM_WORLD_RANK} -eq 0 ]] && tmux kill-pane -t ${dummy} &> /dev/null
+
+    # set the window to tiled mode.
+    # have to do this after every new pane is spawned because otherwise the splits get
+    # smaller and smaller until tmux refuses to open new panes, despite plenty of space being left.
+    tmux select-layout -t ${pane} tiled &> /dev/null
+
+    # let the next process start
+    echo $((${OMPI_COMM_WORLD_RANK}+1)) > /tmp/tmpi.lock
+
+    # don't exit here as mpirun needs to be kept alive and it would also exit.
+    while [[ -d /proc/${pid} ]]; do
+        sleep 1
+    done
+else
+    # we are the parent and set everything up before we start ourselves a bunch of times via mpirun.
+    processes=${1}
+    self=${0}
+    shift
+
+    # create an empty new dummy window which we will later split up for the mpi processes.
+    x=( $(tmux new-window ${session} -P -F '#{pane_id} #{window_id} #{session_id}') )
+    export dummy=${x[0]}
+    export window=${x[1]}
+    export session=${x[2]}
+
+    # synchronize input to all panes.
+    tmux set-window-option -t ${window} synchronize-panes on &> /dev/null
+    tmux set-window-option -t ${window} remain-on-exit on &> /dev/null
+
+    # always start with rank 0.
+    echo 0 > /tmp/tmpi.lock
+
+    # re-execute ourself to spawn off the processes.
+    echo mpirun -np ${processes} ${self} runmpi "${@}"
+    mpirun -np ${processes} ${self} runmpi "${@}"
+fi

setup.py (2 changes)

@@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:

 setuptools.setup(
     name='jittor',
-    version='1.1.4.9',
+    version='1.1.5.2',
     # scripts=[],
     author="Jittor Group",
     author_email="ran.donglang@gmail.com",

@@ -9,7 +9,6 @@
 #include <cuda_runtime.h>
 #include <helper_cuda.h>
 #include "mem/allocator/cuda_dual_allocator.h"
-#include "fetcher.h"
 #include "event_queue.h"
 #endif
 #include "misc/cuda_flags.h"

@@ -26,6 +25,9 @@ namespace jittor {

 Executor exe;

+// from fetch_op.cc
+extern list<VarPtr> fetcher_to_free;
+
 void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     auto allocator = get_allocator();
     this->allocator = allocator;

@@ -33,22 +35,43 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     int op_num = 0;
     vector<Node*> bfs_q;
     bfs_q.reserve(vars.size());
-    auto nodes = (vector<Node*>*)&vars;
     int start_var_num = 0;
-    for (Var* v : vars)
-        if (!v->is_finished())
-            start_var_num++;
-    bfs_backward(*nodes, bfs_q, [&](Node *node) -> bool {
-        node->custom_data = 0;
-        if (node->is_finished())
-            return false;
-        op_num += !node->is_var();
-        return true;
-    });
+    {
+        // get all nodes that need to be executed
+        auto t = ++Node::tflag_count;
+        for (Var* v : vars)
+            if (!v->is_finished() && v->tflag != t) {
+                v->tflag = t;
+                start_var_num++;
+                bfs_q.push_back(v);
+            }
+        for (int i=0; i<bfs_q.size(); i++) {
+            auto node = bfs_q[i];
+            op_num += !node->is_var();
+            for (auto i : node->_inputs)
+                if (i.node->tflag != t && !i.node->is_finished()) {
+                    i.node->tflag = t;
+                    bfs_q.push_back(i.node);
+                }
+            // this var has been fetched
+            if (node->flags.get(NodeFlags::_fetch)) {
+                for (auto& n : node->_outputs) {
+                    // if not in the queue and is a fetch op
+                    if (n.node->tflag != t &&
+                        !n.node->is_finished() &&
+                        n.node->flags.get(NodeFlags::_fetch)) {
+                        n.node->tflag = t;
+                        bfs_q.push_back(n.node);
+                    }
+                }
+            }
+        }
+    }
     auto tt = Node::tflag_count;
     vector<Op*> ops;
     vector<Var*> all_vars;
     ops.reserve(op_num);
     all_vars.reserve(bfs_q.size() - op_num);
     for (Node* node : bfs_q)
         if (!node->is_var()) {
             node->custom_data = ops.size();

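A minimal Python model of the new dependency walk above (an illustrative sketch, not the real C++ executor; `inputs`, `outputs`, `finished`, `is_fetch_op` and `is_fetched` are hypothetical accessors standing in for the node fields):

    def collect_pending(targets, inputs, outputs, finished, is_fetch_op, is_fetched):
        queue = [v for v in targets if v not in finished]
        seen = set(queue)
        for node in queue:                      # queue grows while we iterate
            for dep in inputs(node):            # walk backward over dependencies
                if dep not in seen and dep not in finished:
                    seen.add(dep)
                    queue.append(dep)
            if is_fetched(node):                # a fetched var also drags in its
                for out in outputs(node):       # pending fetch ops (forward edge)
                    if out not in seen and out not in finished and is_fetch_op(out):
                        seen.add(out)
                        queue.append(out)
        return queue
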
@@ -105,7 +128,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     // var_fused represents:
     // 0: can be fused
     // 1: cannot be fused
-    // 2: can be shared
+    // 2: weak shared (may turn into 1 or 3 by shared-operator cutting)
+    // 3: strong shared (forced shared)
     vector<int> roots, next(op_num, -1);
     vector<int> deps(op_num, 0);
     roots.reserve(op_num);

@@ -176,6 +200,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     sharegraph_q.reserve(16);
     vector<int> shared_id(op_num, -1);

+    // for each fused op, in reversed order
     for (uint rid=0; rid<queue.size(); rid++) {
         int root = queue[queue.size()-rid-1];
         auto& queue = subgraph;

@@ -193,10 +218,13 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
                 if (fopid == root)
                     deps[i]++;
                 else if (shared_id[opid] != root) {
+                    auto& vf = var_fused[v->custom_data];
                     // var_fused = 1: the input op cannot be shared
                     // TODO: check that all of this input op's output vars can be shared
-                    if (var_fused[v->custom_data] == 1)
+                    if (vf == 1)
                         continue;
+                    // if weak shared, turn into strong shared
+                    if (vf == 2) vf = 3;
                     // new shared op
                     deps[opid] = 0;
                     shared_id[opid] = root;

@@ -216,6 +244,15 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
             int vi = v->custom_data;
             if (var_fused[vi] == 1)
                 continue;
+            // if weak shared, cut off
+            if (var_fused[vi] == 2) {
+                if (sharegraph.size() - sn < 32)
+                    var_fused[vi] = 3;
+                else {
+                    var_fused[vi] = 1;
+                    continue;
+                }
+            }
             Op* opi = v->input();
             int opid = opi->custom_data;
             int& dep = deps[opid];

@@ -377,7 +414,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
             outputs_bk.push_back(var);
         op->finish_pending_liveness();
         for (Var* var : outputs_bk)
-            // var->finish_pending_liveness();
             var->finish_pending_liveness();
     } catch (const std::exception& e) {
         // log memory info

@@ -396,6 +432,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     }
     LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
     for (Var* v : vars) ASSERT(v->mem_ptr);
+    // clean fetcher free buffer
+    fetcher_to_free.clear();
 #ifdef HAS_CUDA
     if (device_sync && use_cuda) {
         last_is_cuda = false;

src/fuser.cc (18 changes)

@@ -174,10 +174,20 @@ void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vector<Var*>& vars, ...)
                 var_fused[i]=1;
             }
         }
-        if (vf==0) var_fused[i]=1;
-        if (var_fused[i] && vf &&
-            (iop->type()==OpType::broadcast || all_reduce || v->flags.get(NodeFlags::_force_fuse)))
-            var_fused[i]=2;
+        if (vf==0)
+            // cannot be fused
+            var_fused[i]=1;
+        else if (var_fused[i]) {
+            if (iop->type()==OpType::broadcast ||
+                all_reduce ||
+                v->flags.get(NodeFlags::_force_fuse))
+                // strong fused
+                var_fused[i] = 3;
+            else
+                // weak fused
+                var_fused[i] = 2;
+            // var_fused[i] = 3;
+        }
     }
     // output vars cannot be fused
     for (int i=0; i<start_var_num; i++)

@@ -24,7 +24,10 @@ VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
     if (dout == nullptr) return nullptr;
     LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
         << "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
-    return op->grad(out, dout, x, x_index);
+    auto dx = op->grad(out, dout, x, x_index);
+    if (x->loop_options)
+        dx->loop_options = x->loop_options;
+    return dx;
 }

 inline static void assign_attrs(Var* a, Var* b) {

src/init.cc (10 changes)

@@ -11,6 +11,7 @@

 #include "init.h"
 #include "ops/op_register.h"
+#include "var.h"

 namespace jittor {

@@ -21,6 +22,15 @@ unique_ptr<std::default_random_engine> eng;
 vector<set_seed_callback> callbacks;
 int current_seed;

+// from fetch_op.cc
+extern list<VarPtr> fetcher;
+extern list<VarPtr> fetcher_to_free;
+
+void cleanup() {
+    fetcher_to_free.clear();
+    fetcher.clear();
+}
+
 static void init_cuda_devices() {
 #ifdef HAS_CUDA
     int count=0;

@@ -20,4 +20,8 @@ void add_set_seed_callback(set_seed_callback callback);
 extern "C"
 std::default_random_engine* get_random_engine();

+// things that need to be cleaned up before python exits
+// @pyjt(cleanup)
+void cleanup();
+
 } // jittor

@@ -95,7 +95,7 @@ struct DelayFree final : Allocator {
     void free(void* mem_ptr, size_t size, const size_t& allocation) override {
         using namespace cuda_dual_local;
         allocations.emplace_back(mem_ptr, allocation, size, &cuda_dual_allocator);
-        checkCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
+        peekCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
     }

     void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) {

@@ -15,6 +15,7 @@
 #include "mem/allocator/sfrl_allocator.h"
 #include "mem/allocator/stat_allocator.h"
 #include "mem/mem_info.h"
+#include "update_queue.h"

 namespace jittor {

@@ -40,7 +41,7 @@ std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
     return os << o.suffix;
 }

-void display_memory_info(const char* fileline) {
+void display_memory_info(const char* fileline, bool dump_var) {
     int p = 3;
     Log log(fileline, 'i', 0);
     log << "\n=== display_memory_info ===\n";

@@ -51,26 +52,28 @@ void display_memory_info(const char* fileline) {
     log << "hold_vars:" << VarHolder::hold_vars.size()
         << "lived_vars:" << Var::number_of_lived_vars
         << "lived_ops:" << Op::number_of_lived_ops >> '\n';
+    log << "update queue:" << update_queue.queue.size()
+        >> '/' >> update_queue.map.size() >> '\n';

 #ifdef NODE_MEMCHECK
     // get the oldest var
-    vector<Node*> queue;
-    auto t = ++Node::tflag_count;
-    for (auto& vh : VarHolder::hold_vars)
-        if (vh->var->tflag != t) {
-            vh->var->tflag = t;
-            queue.push_back(vh->var);
-        }
-    bfs_both(queue, [](Node*){return true;});
-    vector<pair<int64, Node*>> nodes;
-    nodes.reserve(queue.size());
-    for (auto* node : queue)
-        nodes.push_back({node->__id(), node});
-    std::sort(nodes.begin(), nodes.end());
-    log << "list of the oldest nodes:\n";
-    for (int i=0; i<10 && i<nodes.size(); i++) {
-        log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
-    }
+    // vector<Node*> queue;
+    // auto t = ++Node::tflag_count;
+    // for (auto& vh : VarHolder::hold_vars)
+    //     if (vh->var->tflag != t) {
+    //         vh->var->tflag = t;
+    //         queue.push_back(vh->var);
+    //     }
+    // bfs_both(queue, [](Node*){return true;});
+    // vector<pair<int64, Node*>> nodes;
+    // nodes.reserve(queue.size());
+    // for (auto* node : queue)
+    //     nodes.push_back({node->__id(), node});
+    // std::sort(nodes.begin(), nodes.end());
+    // log << "list of the oldest nodes:\n";
+    // for (int i=0; i<10 && i<nodes.size(); i++) {
+    //     log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
+    // }
 #endif

     if (use_stat_allocator) {

@@ -78,10 +81,15 @@ void display_memory_info(const char* fileline) {
         log << "total alloc:" << FloatOutput{(double)(stat_allocator_total_alloc_byte
             - stat_allocator_total_free_byte), " KMG", 1024, "B"};
         log << "total alloc call:" << FloatOutput{(double)(stat_allocator_total_alloc_call
-            - stat_allocator_total_free_call), " KMG", 1000, ""} >> '\n';
+            - stat_allocator_total_free_call), " KMG", 1000, ""}
+            >> '(' >> stat_allocator_total_alloc_call >> '/' >>
+            stat_allocator_total_free_call >> ")\n";
     }
+    int64 all_total = 0, gpu_total = 0, cpu_total = 0;
     for (auto& a : SFRLAllocator::sfrl_allocators) {
         auto total = a->used_memory + a->unused_memory;
+        all_total += total;
+        a->is_cuda() ? gpu_total += total : cpu_total += total;
         log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
             << "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
             >> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"

@@ -89,6 +97,47 @@ void display_memory_info(const char* fileline) {
             >> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
             << "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
     }
+    log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
+        << "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
+        << "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';
+
+    if (dump_var) {
+        vector<Node*> queue;
+        unordered_set<Node*> visited;
+        for (auto& vh : VarHolder::hold_vars)
+            if (!visited.count(vh->var)) {
+                queue.push_back(vh->var);
+                visited.insert(vh->var);
+            }
+        int64 cum = 0;
+        for (int i=0; i<queue.size(); i++) {
+            for (auto* n : queue[i]->inputs())
+                if (!visited.count(n)) {
+                    queue.push_back(n);
+                    visited.insert(n);
+                }
+            for (auto* n : queue[i]->outputs())
+                if (!visited.count(n)) {
+                    queue.push_back(n);
+                    visited.insert(n);
+                }
+            if (queue[i]->is_var()) {
+                auto v = (Var*)queue[i];
+                if (v->size>=0 && v->mem_ptr) {
+                    cum += v->size;
+                    log << FloatOutput{(double)v->size, " KMG", 1024, "B"}
+                        >> "(" >> std::setprecision(p) >> v->size*100.0 / all_total >> "%)"
+                        << FloatOutput{(double)cum, " KMG", 1024, "B"}
+                        >> "(" >> std::setprecision(p) >> cum*100.0 / all_total >> "%)"
+                        << v >> "\n";
+                    if (v->size == 100*64*112*112*4) {
+                        for (auto op : v->outputs())
+                            log << "\t" << op << '\n';
+                    }
+                }
+            }
+        }
+    }
     log >> "===========================\n";
     log.end();
 }

@@ -9,7 +9,7 @@
 namespace jittor {

 // @pyjt(display_memory_info)
-void display_memory_info(const char* fileline="");
+void display_memory_info(const char* fileline="", bool dump_var=false);

 // @pyjt(MemInfo)
 struct MemInfo {

@@ -9,9 +9,6 @@
 namespace jittor {

 #define FOR_ALL_TYPES(m) \
-    m(float) \
-    m(double) \
-    m(int) \
     m(bool) \
     m(int8) \
     m(int16) \

@@ -151,6 +148,10 @@ static void init_ns() {
     NanoString::__string_to_ns["sum"] = ns_add;
     NanoString::__string_to_ns["min"] = ns_minimum;
     NanoString::__string_to_ns["max"] = ns_maximum;
+    NanoString::__string_to_ns["float"] = ns_float32;
+    NanoString::__string_to_ns["double"] = ns_float64;
+    NanoString::__string_to_ns["int"] = ns_int32;
+    NanoString::__string_to_ns["uint"] = ns_uint32;
     LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
     LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
 }

@@ -12,9 +12,6 @@ namespace jittor {
 #define FOR_ALL_NS(m) \
 \
     m(void) \
-    m(float) \
-    m(double) \
-    m(int) \
     m(bool) \
     m(int8) \
     m(int16) \

@@ -24,11 +24,14 @@ struct NodeFlags {
         _finished=1,
         // bit2: stop grad
         _stop_grad=2,
-        _n=3,
+        // bit3: is fetch
+        _fetch=3,
+        _n=4,

-        // op related flags
+        // var related flags
         _force_fuse=_n+0,
         _stop_fuse=_n+1,
+        _in_update_queue=_n+2,

         // op related flags
         // bit0: support cpu

@@ -32,9 +32,9 @@ Init() {
     }
     ~Init() {
         if (!get_device_count()) return;
-        checkCudaErrors(cudaDeviceSynchronize());
-        checkCudaErrors(cudaStreamDestroy(stream));
-        checkCudaErrors(cudaEventDestroy(event));
+        peekCudaErrors(cudaDeviceSynchronize());
+        peekCudaErrors(cudaStreamDestroy(stream));
+        peekCudaErrors(cudaEventDestroy(event));
     }
 } init;

@@ -20,7 +20,7 @@ struct ArrayOp : Op {
     Var* output;
     Allocation allocation;
     // @pybind(None)
-    ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float);
+    ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float32);

     ArrayOp(ArrayArgs&& args);
     template<class T>

@@ -31,4 +31,11 @@ void CloneOp::infer_shape() {
     y->set_shape(x->shape);
     y->share_with(x);
 }

+VarPtr detach(Var* x) {
+    auto y = make_clone(x);
+    y->input()->set_stop_grad();
+    return y;
+}
+
 } // jittor

@@ -18,4 +18,7 @@ struct CloneOp : Op {
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
     void infer_shape() override;
 };

+VarPtr detach(Var* x);
+
 } // jittor

@@ -1,5 +1,7 @@
 // ***************************************************************
-// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
+// Copyright (c) 2020 Jittor.
+// Authors: Dun Liang <randonlang@gmail.com>.
+// All Rights Reserved.
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************

@@ -12,8 +14,9 @@
 #include "mem/allocator/cuda_dual_allocator.h"
 #include "event_queue.h"
 #endif
-#include "fetcher.h"
+#include "ops/fetch_op.h"
 #include "mem/allocator.h"
+#include "executor.h"

 namespace jittor {

@@ -49,31 +52,68 @@ Init() {
     // do not call deleter on exit
     for (auto& f : fetch_tasks)
         f.func.deleter = nullptr;
-    checkCudaErrors(cudaDeviceSynchronize());
-    checkCudaErrors(cudaStreamDestroy(stream));
-    checkCudaErrors(cudaEventDestroy(event));
+    peekCudaErrors(cudaDeviceSynchronize());
+    peekCudaErrors(cudaStreamDestroy(stream));
+    peekCudaErrors(cudaEventDestroy(event));
   }
-};
+} ;

 }
 using namespace fetcher_local;

 #endif

-void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
+list<VarPtr> fetcher;
+// this list will be freed at each execution
+list<VarPtr> fetcher_to_free;
+
+FetchOp::FetchOp(vector<Var*>&& inputs, FetchFunc&& func)
+: fetch_vars(inputs), func(move(func)) {
 #ifdef HAS_CUDA
-    static Init init;
+    // the stream needs to be created after the nccl plugin
+    static Init init_fetch;
 #endif
-    sync(vh);
-    vector<Allocation> allocations(vh.size());
-    vector<ArrayArgs> arrays(vh.size());
+    VarPtr vp(0, ns_int32);
+    outputs_holder.emplace_back(vp);
+    fetcher.emplace_front(move(vp));
+    fetcher_iter = fetcher.begin();
+    bool all_finished = true;
+    for (auto v : fetch_vars)
+        if (!v->is_finished()) {
+            all_finished = false;
+            v->flags.set(NodeFlags::_stop_fuse);
+            v->flags.set(NodeFlags::_fetch);
+        }
+    flags.set(NodeFlags::_cpu);
+    flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_fetch);
+    flags.set(NodeFlags::_stop_grad);
+    fetcher_iter->ptr->flags.set(NodeFlags::_fetch);
+    // fetcher_to_free.clear();
+    if (all_finished) {
+        // if all finished, run immediately
+        run();
+    }
+    // if too many fetchers are buffered, force a flush
+    while (fetcher.size() > 20) {
+        LOGvvvv << "too many fetchers(">>fetcher.size() >>
+            ") are buffered, force flush";
+        exe.run_sync({fetcher.back().ptr}, false);
+    }
+}
+
+void FetchOp::run() {
+    vector<Allocation> allocations(fetch_vars.size());
+    vector<ArrayArgs> arrays(fetch_vars.size());
 #ifdef HAS_CUDA
     bool has_cuda_memcpy = false;
     event_queue.flush();
 #endif
-    for (int i=0; i<vh.size(); i++) {
-        auto v = vh[i]->var;
+    LOGvvvv << "fetch" << fetch_vars.size() << "vars" << fetch_vars;
+    int i = 0;
+    for (auto v : fetch_vars) {
         auto& allocation = allocations[i];

 #ifdef HAS_CUDA
         if (v->allocator->is_cuda()) {
             checkCudaErrors(cudaEventRecord(event, 0));

@@ -98,6 +138,7 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
         arrays[i].ptr = allocation.ptr;
         arrays[i].shape = v->shape;
         arrays[i].dtype = v->dtype();
+        i++;
     }
 #ifdef HAS_CUDA
     if (has_cuda_memcpy) {

@@ -109,6 +150,8 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
         FetchResult fr{move(func), move(allocations), move(arrays)};
         fr.call();
     }
+    fetcher_to_free.emplace_front(move(*fetcher_iter));
+    fetcher.erase(fetcher_iter);
 }

 } // jittor

@@ -5,8 +5,9 @@
 // ***************************************************************
 #pragma once
 #include <functional>
-#include "common.h"
-#include "var_holder.h"
+#include "op.h"
+#include "var.h"
 #include "mem/allocator.h"
 #include "ops/array_op.h"

 namespace jittor {

@@ -42,7 +43,15 @@ struct FetchResult {
     inline void call() { func.callback(this); }
 };

-// @pyjt(fetch)
-void fetch(const vector<VarHolder*>& vh, FetchFunc&& func);
-
-} // jittor
+struct FetchOp final : Op {
+    vector<Var*> fetch_vars;
+    FetchFunc func;
+    list<VarPtr>::iterator fetcher_iter;
+
+    FetchOp(vector<Var*>&& inputs, FetchFunc&& func);
+
+    const char* name() const override { return "fetch"; }
+    void run() override;
+};
+
+} // jittor

@@ -16,7 +16,7 @@ static auto make_broadcast_to = get_op_info("broadcast_to")
     .get_constructor<VarPtr, Var*, Var*, NanoVector>();

 VarPtr make_number(float number, Var* x) {
-    VarPtr nums = make_array(&number, 1, ns_float);
+    VarPtr nums = make_array(&number, 1, ns_float32);
     nums = make_broadcast_to(nums, x, {});
     return make_unary(nums, x->dtype());
 }

@@ -10,7 +10,7 @@ namespace jittor {

 struct RandomOp : Op {
     Var* output;
-    RandomOp(NanoVector shape, NanoString dtype=ns_float);
+    RandomOp(NanoVector shape, NanoString dtype=ns_float32);

     const char* name() const override { return "random"; }
     DECLARE_jit_run;

@@ -22,9 +22,6 @@ static auto make_number = get_op_info("number")
     .get_constructor<VarPtr, float, Var*>();

 static unordered_set<string> unary_ops = {
-    "float",
-    "double",
-    "int",
     "bool",
     "int8",
     "int16",

@@ -228,6 +228,9 @@ void ConvTuner::forwardTune(FusedOp* fop) {
         if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
         if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return;

+        // only float32 is supported currently
+        if (bop->z->dtype() != ns_float32)
+            continue;
         Op* ops[3] = {op, bop->x->input(), bop->y->input()};
         int ok = 0;
         LOGvvvv << "conv like op" << fop << fop->get_jit_key();

@@ -15,6 +15,7 @@ PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, ...);
 PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
 unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
 int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
+PyObject* (*PyArray_NewCopy)(PyObject *, int);

 tmp_data_t tmp_data;

@@ -30,6 +31,7 @@ void numpy_init() {
     fill(PyArray_New, 93);
     fill(PyArray_GetNDArrayCFeatureVersion, 211);
     fill(PyArray_SetBaseObject, 282);
+    fill(PyArray_NewCopy, 85);

     ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7);
 }

@@ -76,12 +76,12 @@ inline int get_typenum(NanoString ns) {
     if (ns == ns_uint8) return 2;
     if (ns == ns_int16) return 3;
     if (ns == ns_uint16) return 4;
-    if (ns == ns_int32 || ns == ns_int) return 5;
+    if (ns == ns_int32) return 5;
     if (ns == ns_uint32) return 6;
     if (ns == ns_int64) return 7;
     if (ns == ns_uint64) return 8;
-    if (ns == ns_float32 || ns == ns_float) return 11;
-    if (ns == ns_float64 || ns == ns_double) return 12;
+    if (ns == ns_float32) return 11;
+    if (ns == ns_float64) return 12;
     LOGf << ns;
     return -1;
 }

@@ -97,6 +97,8 @@ extern PyObject* (*PyArray_New)(...);
 extern PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
 extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
 extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
+extern PyObject* (*PyArray_NewCopy)(PyObject *, int);
+#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)

 #define NPY_ARRAY_ALIGNED 0x0100
 #define NPY_ARRAY_WRITEABLE 0x0400

@@ -293,8 +293,13 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
         auto ptr = GET_RAW_PTR(VarHolder, obj);
         return move(fetch_sync({ptr}).at(0));
     }
-    if (Py_TYPE(obj) != PyArray_Type) {
-        PyObjHolder holder(PyArray_FROM_O(obj));
+    // PyArray_Type
+    auto arr = (PyArray_Proxy*)obj;
+    if (Py_TYPE(obj) != PyArray_Type || !is_c_style(arr)) {
+        PyObjHolder holder(
+            Py_TYPE(obj) != PyArray_Type ?
+            PyArray_FROM_O(obj) :
+            PyArray_Copy(obj));
         auto arr = (PyArray_Proxy*)holder.obj;
         int64 size = PyArray_Size(arr);
         T args;

@@ -305,9 +310,6 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
         memcpy((void*)args.buffer.get(), (void*)arr->data, size);
         return args;
     }
-    // PyArray_Type
-    auto arr = (PyArray_Proxy*)obj;
-    CHECK(is_c_style(arr));
     T args;
     args.ptr = arr->data;
     if (arr->dimensions)

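Taken together, the two hunks above normalize the input in one place: a non-array is converted with PyArray_FROM_O, and an array that is not C-contiguous is copied with PyArray_Copy (i.e. PyArray_NewCopy(obj, 0), where 0 is NPY_CORDER, so the copy comes out C-contiguous); a C-style array takes the zero-copy path. A self-contained model of that decision structure, with stand-in types and predicates:

    #include <cstdio>

    struct Obj { bool is_array; bool c_style; }; // stand-in for PyObject

    // Stand-ins for PyArray_FROM_O / PyArray_Copy: both yield a
    // C-contiguous array.
    static Obj from_any(Obj)  { return {true, true}; }
    static Obj copy_c(Obj o)  { return {o.is_array, true}; }

    static Obj normalize(Obj o) {
        if (!o.is_array || !o.c_style)
            return !o.is_array ? from_any(o) : copy_c(o);
        return o; // already a C-style array: zero-copy path
    }

    int main() {
        Obj fortran_arr{true, false};
        std::printf("%d\n", normalize(fortran_arr).c_style); // prints: 1
        return 0;
    }
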
@@ -0,0 +1,148 @@
+// ***************************************************************
+// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#include "update_queue.h"
+#include "executor.h"
+#include "node.h"
+#include "var.h"
+#include "var_holder.h"
+
+namespace jittor {
+
+/*
+
+The update queue is designed to batch update parameters asynchronously.
+It maintains several queues internally.
+Each updated parameter corresponds to a queue,
+and the elements in the queue represent several updates of this parameter.
+When a parameter is updated,
+jittor internally executes several earlier updates of this parameter
+instead of the current one.
+
+The figure below shows an async update process.
+
+first iter:
+
+\ iter  0
+params
+a       0
+b       0
+c       0
+d       0
+
+second iter:
+
+\ iter  0  1
+params
+a       0  1
+b       0  1
+c       0  1
+d       0  1
+
+third iter begin; the updates of iter 0 are executed:
+
+\ iter  0  1  2
+params
+a      [0] 1  2
+b      [0] 1
+c      [0] 1
+d      [0] 1
+
+third iter end:
+
+\ iter  0  1  2
+params
+a          1  2
+b          1  2
+c          1  2
+d          1  2
+
+update_queue_auto_flush_delay: by how many iterations updates are delayed.
+
+The update queue was introduced to solve the problem of the computation
+graph growing without bound (the number of lived vars keeps increasing).
+Before the update queue, graph execution was driven by the optimizer:
+when optim.step was called, the not-yet-executed part of the graph was
+run automatically, executed graph nodes were recycled, and the graph
+stayed at a constant size across iterations.
+
+But if the user never calls optim.step, the graph keeps growing, for
+example in the following two cases:
+
+* When training a GAN, only the generator is stepped with SGD and the
+  discriminator is not; the batch norm statistics of the discriminator
+  are updated continuously but never executed, so the graph keeps
+  growing.
+* The user forgets to set model.eval during inference; since no SGD
+  flushes the parameters, the batch norm statistics are again updated
+  continuously and the graph keeps growing.
+
+These details are too hard for users to reason about (LD: even I get
+confused sometimes). A blunt fix is jt.sync_all, which forces the whole
+graph to run, but it uses too much memory, because the topological
+order sync_all executes in is poor.
+
+So that users need not care about any of this, parameter updates go
+through var.update(new_var); this interface hands the update over to
+the update queue, and the size of the underlying graph no longer needs
+attention.
+
+*/
+
+DEFINE_FLAG(int, update_queue_auto_flush_delay, 2, "when the size of an update queue is greater than this value, the update queue triggers an auto flush (default 2).");
+
+UpdateQueue update_queue;
+
+void UpdateQueue::auto_flush() {
+    vector<Var*> vars;
+    vars.reserve(queue.size());
+    for (auto& l : queue) {
+        // pop the oldest updates (stored at the back) until the queue
+        // drops below the flush threshold
+        while (l.size() && l.size() >= update_queue_auto_flush_delay) {
+            auto iter = l.end(); iter--;
+            auto v = iter->v;
+            vars.push_back(v);
+            map.erase(v);
+            v->flags.set(NodeFlags::_in_update_queue, 0);
+            l.pop_back();
+        }
+    }
+    LOGvv << "auto flush var size" << vars.size();
+    exe.run_sync(move(vars), false);
+}
+
+void UpdateQueue::push(Var* v, Var* prev) {
+    if (v->flags.get(NodeFlags::_in_update_queue))
+        return;
+    v->flags.set(NodeFlags::_in_update_queue);
+    list<list<Item>>::iterator owner;
+    // updates of the same parameter share one queue: reuse the queue
+    // of the previous version if it is still enqueued
+    if (prev->flags.get(NodeFlags::_in_update_queue)) {
+        auto iter = map.find(prev);
+        ASSERT(iter != map.end());
+        owner = iter->second->owner;
+    } else {
+        queue.emplace_front();
+        owner = queue.begin();
+    }
+    if (owner->size() >= update_queue_auto_flush_delay)
+        auto_flush();
+    owner->emplace_front(UpdateQueue::Item{owner, v});
+    map[v] = owner->begin();
+    // if the total size of the update queue is too big,
+    // force sync all
+    if (map.size() > 100000)
+        sync_all();
+}
+
+void UpdateQueue::pop(Var* v) {
+    auto iter = map.find(v);
+    iter->second->owner->erase(iter->second);
+    map.erase(iter);
+}
+
+} // jittor

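To make the flush policy concrete, here is a self-contained simulation of the behaviour documented above (not Jittor code; all names are illustrative). Each parameter owns a FIFO of pending updates; pushing onto a full queue first flushes the oldest pending update of every parameter, which is the [0] column being executed at the beginning of the third iteration in the figure:

    #include <cstdio>
    #include <deque>
    #include <map>
    #include <string>
    #include <vector>

    static const size_t flush_delay = 2; // mirrors update_queue_auto_flush_delay

    std::map<std::string, std::deque<int>> queues; // param name -> pending update ids

    // Push one update; if this param's queue reached the delay, flush
    // the oldest pending update of every parameter (batched execution).
    void push(const std::string& param, int update_id) {
        auto& q = queues[param];
        if (q.size() >= flush_delay) {
            std::vector<int> flushed;
            for (auto& kv : queues)
                if (!kv.second.empty()) {
                    flushed.push_back(kv.second.front());
                    kv.second.pop_front();
                }
            std::printf("flush %zu updates\n", flushed.size());
        }
        q.push_back(update_id);
    }

    int main() {
        for (int iter = 0; iter < 3; iter++)
            for (auto p : {"a", "b", "c", "d"})
                push(p, iter); // iter 2 triggers "flush 4 updates"
        return 0;
    }
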
@@ -0,0 +1,27 @@
+// ***************************************************************
+// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "common.h"
+
+namespace jittor {
+
+struct UpdateQueue {
+    struct Item {
+        // iterator to the per-parameter queue this item lives in
+        list<list<Item>>::iterator owner;
+        Var* v;
+    };
+    // one inner list per parameter, holding its pending updates
+    list<list<Item>> queue;
+    // position of each enqueued Var inside its owner queue
+    unordered_map<Var*, list<Item>::iterator> map;
+
+    void push(Var* v, Var* prev);
+    void pop(Var* v);
+    void auto_flush();
+};
+
+extern UpdateQueue update_queue;
+
+} // jittor

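A design note on the header above: Item::owner and the map both store iterators into std::list nodes. This is sound only because std::list iterators remain valid when other elements are inserted or erased (unlike std::vector, where reallocation invalidates them), which is what lets pop() erase an arbitrary element in O(1). A minimal demonstration of the property:

    #include <cstdio>
    #include <iterator>
    #include <list>

    int main() {
        std::list<int> l = {1, 2, 3};
        auto it = std::next(l.begin()); // points at 2
        l.push_front(0);                // list iterators survive insertion...
        l.erase(l.begin());             // ...and erasure of other elements
        std::printf("%d\n", *it);       // prints: 2
        return 0;
    }
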
@@ -9,6 +9,7 @@
 #include "op.h"
 #include "mem/allocator.h"
 #include "pybind/py_var_tracer.h"
+#include "update_queue.h"

 namespace jittor {

@@ -30,6 +31,8 @@ Var::~Var() {
     if (mem_ptr != nullptr)
         allocator->free(mem_ptr, size, allocation);
     number_of_lived_vars--;
+    if (flags.get(NodeFlags::_in_update_queue))
+        update_queue.pop(this);
 }

 string Var::to_string() {

@@ -12,6 +12,7 @@
 #include "var.h"
 #include "executor.h"
 #include "graph.h"
+#include "update_queue.h"

 namespace jittor {

@@ -28,8 +29,9 @@ VarHolder::VarHolder(Var* v) : var(v) {
     var->own_both_liveness();
 }

-VarHolder::VarHolder(VarPtr&& v) : VarHolder(v.ptr) {
-    v.free_liveness();
+VarHolder::VarHolder(VarPtr&& v) {
+    add_hold_vars(this);
+    var = v.ptr;
+    v.ptr = nullptr;
 }

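The rewritten constructor replaces the delegate-then-free dance with a direct steal of the pointer: register the holder, take v.ptr, and null it so the moved-from VarPtr's destructor has nothing to release. A standalone sketch of that pointer-stealing pattern, with stand-in types:

    #include <cstdio>
    #include <utility>

    struct Ptr { int* p; };   // stand-in for VarPtr

    struct Holder {           // stand-in for VarHolder
        int* p;
        explicit Holder(Ptr&& v) : p(v.p) {
            v.p = nullptr;    // steal the pointer; the moved-from Ptr
        }                     // then has nothing left to release
    };

    int main() {
        int x = 7;
        Ptr v{&x};
        Holder h(std::move(v));
        std::printf("%d %d\n", *h.p, v.p == nullptr); // prints: 7 1
        return 0;
    }
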
@@ -74,6 +76,13 @@ VarHolder* VarHolder::assign(VarHolder* v) {
     return this;
 }

+VarHolder* VarHolder::update(VarHolder* v) {
+    auto dv = jittor::detach(v->var);
+    update_queue.push(dv.ptr, var);
+    *this = move(dv);
+    return this;
+}
+
 extern Executor exe;

 void VarHolder::sync(bool device_sync) {

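For contrast with assign, which takes effect eagerly, update records the detached new value in the update queue and lets it run later in a batch. A toy model of the two semantics (illustrative only, no Jittor types):

    #include <cstdio>
    #include <vector>

    // Toy contrast of the two semantics: assign applies the value at
    // once; update only records it, and pending values run in a batch.
    struct Param {
        int value = 0;
        std::vector<int> pending;
        void assign(int v) { value = v; }            // eager
        void update(int v) { pending.push_back(v); } // deferred
        void flush() {
            for (int v : pending) value = v;
            pending.clear();
        }
    };

    int main() {
        Param p;
        p.update(1);
        p.update(2);
        std::printf("%d\n", p.value); // prints: 0 (nothing ran yet)
        p.flush();
        std::printf("%d\n", p.value); // prints: 2 (batched updates applied)
        return 0;
    }
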
@@ -88,6 +97,9 @@ ArrayArgs VarHolder::fetch_sync() {
     return {var->mem_ptr, var->shape, var->dtype()};
 }

+// from fetch_op.cc
+extern list<VarPtr> fetcher;
+
 void sync_all(bool device_sync) {
     vector<Var*> vars;
     vars.reserve(VarHolder::hold_vars.size());

@@ -95,6 +107,8 @@ void sync_all(bool device_sync) {
         if (!v->var->_outputs.size())
             vars.push_back(v->var);
     }
+    for (auto& v : fetcher)
+        vars.push_back(v.ptr);
     graph_check();
     exe.run_sync(vars, device_sync); // need sync at last
     graph_check();

@@ -13,6 +13,7 @@
 namespace jittor {

 struct VarHolder;
+VarPtr detach(Var* x);

 struct DataView {
     VarHolder* vh;

@@ -42,6 +43,15 @@ struct VarHolder {
     // @attrs(return_self)
     VarHolder* assign(VarHolder* v);

+    /* update a parameter or global variable;
+       unlike assign, update stops the gradient between the
+       original var and the assigned var, and the update is
+       executed in the background. */
+    // @pyjt(update)
+    // @attrs(return_self)
+    VarHolder* update(VarHolder* v);
+
     // @pyjt(swap)
     // @attrs(return_self)
     inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };

@@ -93,6 +103,13 @@ struct VarHolder {
         return var->is_stop_grad();
     }

+    /* detach the grad */
+    // @pyjt(detach)
+    inline VarHolder* detach() {
+        return new VarHolder(move(jittor::detach(var)));
+    }
+
+
     // @pyjt(stop_fuse)
     // @attrs(return_self)
     inline VarHolder* stop_fuse() {