mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/jittor/jittor
This commit is contained in:
commit
ca7207e027
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.0.5'
|
||||
__version__ = '1.2.0.6'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
@ -575,6 +575,10 @@ class Module:
|
|||
ss.append(s)
|
||||
return ", ".join(ss)
|
||||
|
||||
def apply(self, func):
|
||||
for m in self.modules():
|
||||
func(m)
|
||||
|
||||
def load_parameters(self, params):
|
||||
n_failed = 0
|
||||
for key in params.keys():
|
||||
|
|
|
@ -1142,6 +1142,14 @@ class Sequential(Module):
|
|||
self.append(mod)
|
||||
def __getitem__(self, idx):
|
||||
return self.layers[idx]
|
||||
def __iter__(self):
|
||||
return self.layers.values().__iter__()
|
||||
def keys(self):
|
||||
return self.layers.keys()
|
||||
def values(self):
|
||||
return self.layers.values()
|
||||
def items(self):
|
||||
return self.layers.items()
|
||||
def execute(self, x):
|
||||
for k, layer in self.layers.items():
|
||||
x = layer(x)
|
||||
|
|
|
@ -149,6 +149,14 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
|
|||
da = jt.grad(a.sigmoid(), a)
|
||||
assert np.isnan(da.data).sum()==0, da.data
|
||||
|
||||
def test_sequential(self):
|
||||
x = jt.nn.Sequential(lambda x:x, lambda x:x)
|
||||
n = 0
|
||||
for a in x:
|
||||
n += 1
|
||||
assert n == 2
|
||||
assert list(x.keys()) == [0,1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -133,6 +133,13 @@ class TestSlice(unittest.TestCase):
|
|||
assert np.allclose(da.numpy(), nda, atol = 1e-3)
|
||||
assert np.allclose(db.numpy(), ndb, atol = 1e-3)
|
||||
|
||||
def test_vary_shape_setitem(self):
|
||||
a = jt.array([1,2,3,4,5])
|
||||
b = jt.array([1,2,3,4,5])
|
||||
c = jt.where(b>3)
|
||||
a[c] = 0
|
||||
assert (a.data == [1,2,3,0,0]).all()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
20
setup.py
20
setup.py
|
@ -1,5 +1,16 @@
|
|||
error_msg = "Jittor only supports Ubuntu>=16.04 currently."
|
||||
error_msg = """Jittor only supports Ubuntu>=16.04 currently.
|
||||
For other OS, use Jittor may be risky.
|
||||
We strongly recommended docker installation:
|
||||
|
||||
# CPU only
|
||||
>>> docker run -it --network host jittor/jittor
|
||||
# CPU and CUDA
|
||||
>>> docker run -it --network host jittor/jittor-cuda
|
||||
|
||||
Reference:
|
||||
1. Windows/Mac/Linux通过Docker安装计图: https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/
|
||||
"""
|
||||
from warnings import warn
|
||||
try:
|
||||
with open("/etc/os-release", "r", encoding='utf8') as f:
|
||||
s = f.read().splitlines()
|
||||
|
@ -7,9 +18,10 @@ try:
|
|||
for line in s:
|
||||
a = line.split('=')
|
||||
m[a[0]] = a[1].replace("\"", "")
|
||||
except:
|
||||
raise RuntimeError(error_msg)
|
||||
assert m["NAME"] == "Ubuntu" and float(m["VERSION_ID"])>16, error_msg
|
||||
assert m["NAME"] == "Ubuntu" and float(m["VERSION_ID"])>16, error_msg
|
||||
except Exception as e:
|
||||
print(e)
|
||||
warn(error_msg)
|
||||
|
||||
import setuptools
|
||||
from setuptools import setup, find_packages
|
||||
|
|
|
@ -369,6 +369,10 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
auto zeros = make_number(0, v);
|
||||
// TODO: maybe add here?
|
||||
// need analysis the overlap attr os var slices
|
||||
for (int i=0; i<vs.n; i++)
|
||||
if (vs.slices[i].is_var()) {
|
||||
return make_setitem(zeros, VarSlices(vs), dout, ns_add);
|
||||
}
|
||||
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,8 +27,6 @@ namespace jittor {
|
|||
|
||||
static auto make_array = get_op_info("array")
|
||||
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
|
||||
static auto make_number = get_op_info("number")
|
||||
.get_constructor<VarPtr, float, Var*>();
|
||||
static auto make_getitem = get_op_info("getitem")
|
||||
.get_constructor<VarPtr, Var*, VarSlices&&>();
|
||||
static auto make_setitem = get_op_info("setitem")
|
||||
|
@ -176,6 +174,14 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void SetitemOp::jit_prepare() {
|
||||
for (int i=0; i<o_shape.size(); i++)
|
||||
if (o_shape[i]<0) {
|
||||
// because output shape is inferd, check in
|
||||
// executor not work
|
||||
// reinfer shape if o_shape has vary shape
|
||||
infer_shape();
|
||||
break;
|
||||
}
|
||||
auto data = input(1);
|
||||
add_jit_define("OP", op);
|
||||
add_jit_define("Td", data->dtype());
|
||||
|
@ -316,10 +322,18 @@ void SetitemOp::jit_run() {
|
|||
)
|
||||
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
|
||||
|
||||
@if(@strcmp(@OP,void)==0,
|
||||
op[iid] = (Ti)dp[did],
|
||||
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
|
||||
);
|
||||
@if(@is_def(JIT_cpu),
|
||||
@if(@strcmp(@OP,void)==0,
|
||||
op[iid] = (Ti)dp[did],
|
||||
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
|
||||
);
|
||||
,
|
||||
@if(@strcmp(@OP,void)==0, op[iid] = (Ti)dp[did],
|
||||
@if(@strcmp(@OP,add)==0, atomicAdd(&op[iid], (Ti)dp[did]),
|
||||
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
|
||||
)
|
||||
);
|
||||
)
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
|
|
@ -105,7 +105,7 @@ __global__ static void where_kernel_one_warp(
|
|||
{
|
||||
index_t i@{NDIM-1} = i + tid;
|
||||
auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
|
||||
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? condp[condid] : 0;
|
||||
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? !!condp[condid] : 0;
|
||||
uint prefix_x = prefix_sum(x, tid);
|
||||
if (x) {
|
||||
uint cn = n + prefix_x - 1;
|
||||
|
@ -142,7 +142,7 @@ __global__ static void where_kernel_one_block(
|
|||
{
|
||||
index_t i@{NDIM-1} = i + tid;
|
||||
auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
|
||||
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? condp[condid] : 0;
|
||||
uint x = i@{NDIM-1}<condshape@{NDIM-1} ? !!condp[condid] : 0;
|
||||
uint prefix_x = prefix_sum(x, lid);
|
||||
uint warp_sum = bc(prefix_x, 31);
|
||||
|
||||
|
|
Loading…
Reference in New Issue