This commit is contained in:
li-xl 2020-10-21 11:37:06 +08:00
commit ca7207e027
8 changed files with 70 additions and 13 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.0.5'
__version__ = '1.2.0.6'
from . import lock
with lock.lock_scope():
from . import compiler
@ -575,6 +575,10 @@ class Module:
ss.append(s)
return ", ".join(ss)
def apply(self, func):
for m in self.modules():
func(m)
def load_parameters(self, params):
n_failed = 0
for key in params.keys():

View File

@ -1142,6 +1142,14 @@ class Sequential(Module):
self.append(mod)
def __getitem__(self, idx):
return self.layers[idx]
def __iter__(self):
return self.layers.values().__iter__()
def keys(self):
return self.layers.keys()
def values(self):
return self.layers.values()
def items(self):
return self.layers.items()
def execute(self, x):
for k, layer in self.layers.items():
x = layer(x)

View File

@ -149,6 +149,14 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
da = jt.grad(a.sigmoid(), a)
assert np.isnan(da.data).sum()==0, da.data
def test_sequential(self):
x = jt.nn.Sequential(lambda x:x, lambda x:x)
n = 0
for a in x:
n += 1
assert n == 2
assert list(x.keys()) == [0,1]
if __name__ == "__main__":
unittest.main()

View File

@ -133,6 +133,13 @@ class TestSlice(unittest.TestCase):
assert np.allclose(da.numpy(), nda, atol = 1e-3)
assert np.allclose(db.numpy(), ndb, atol = 1e-3)
def test_vary_shape_setitem(self):
a = jt.array([1,2,3,4,5])
b = jt.array([1,2,3,4,5])
c = jt.where(b>3)
a[c] = 0
assert (a.data == [1,2,3,0,0]).all()
if __name__ == "__main__":

View File

@ -1,5 +1,16 @@
error_msg = "Jittor only supports Ubuntu>=16.04 currently."
error_msg = """Jittor only supports Ubuntu>=16.04 currently.
For other OS, use Jittor may be risky.
We strongly recommended docker installation:
# CPU only
>>> docker run -it --network host jittor/jittor
# CPU and CUDA
>>> docker run -it --network host jittor/jittor-cuda
Reference:
1. Windows/Mac/Linux通过Docker安装计图: https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/
"""
from warnings import warn
try:
with open("/etc/os-release", "r", encoding='utf8') as f:
s = f.read().splitlines()
@ -7,9 +18,10 @@ try:
for line in s:
a = line.split('=')
m[a[0]] = a[1].replace("\"", "")
except:
raise RuntimeError(error_msg)
assert m["NAME"] == "Ubuntu" and float(m["VERSION_ID"])>16, error_msg
assert m["NAME"] == "Ubuntu" and float(m["VERSION_ID"])>16, error_msg
except Exception as e:
print(e)
warn(error_msg)
import setuptools
from setuptools import setup, find_packages

View File

@ -369,6 +369,10 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
auto zeros = make_number(0, v);
// TODO: maybe add here?
// need analysis the overlap attr os var slices
for (int i=0; i<vs.n; i++)
if (vs.slices[i].is_var()) {
return make_setitem(zeros, VarSlices(vs), dout, ns_add);
}
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
}

View File

@ -27,8 +27,6 @@ namespace jittor {
static auto make_array = get_op_info("array")
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
static auto make_getitem = get_op_info("getitem")
.get_constructor<VarPtr, Var*, VarSlices&&>();
static auto make_setitem = get_op_info("setitem")
@ -176,6 +174,14 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
}
void SetitemOp::jit_prepare() {
for (int i=0; i<o_shape.size(); i++)
if (o_shape[i]<0) {
// because output shape is inferd, check in
// executor not work
// reinfer shape if o_shape has vary shape
infer_shape();
break;
}
auto data = input(1);
add_jit_define("OP", op);
add_jit_define("Td", data->dtype());
@ -316,10 +322,18 @@ void SetitemOp::jit_run() {
)
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
@if(@strcmp(@OP,void)==0,
op[iid] = (Ti)dp[did],
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
);
@if(@is_def(JIT_cpu),
@if(@strcmp(@OP,void)==0,
op[iid] = (Ti)dp[did],
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
);
,
@if(@strcmp(@OP,void)==0, op[iid] = (Ti)dp[did],
@if(@strcmp(@OP,add)==0, atomicAdd(&op[iid], (Ti)dp[did]),
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
)
);
)
}
}
#endif // JIT

View File

@ -105,7 +105,7 @@ __global__ static void where_kernel_one_warp(
{
index_t i@{NDIM-1} = i + tid;
auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? condp[condid] : 0;
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? !!condp[condid] : 0;
uint prefix_x = prefix_sum(x, tid);
if (x) {
uint cn = n + prefix_x - 1;
@ -142,7 +142,7 @@ __global__ static void where_kernel_one_block(
{
index_t i@{NDIM-1} = i + tid;
auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? condp[condid] : 0;
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? !!condp[condid] : 0;
uint prefix_x = prefix_sum(x, lid);
uint warp_sum = bc(prefix_x, 31);