Merge branch 'master' of github.com:Jittor/jittor into win_cuda

commit 9293045993
Author: Dun Liang
Date:   2021-09-28 17:23:11 +08:00

7 changed files with 14 additions and 9 deletions


@@ -9,7 +9,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.3.103'
+__version__ = '1.2.3.105'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -804,7 +804,7 @@ class Module:
         self.dfs([], None, callback, callback_leave)
         return _uniq(ps)
 
-    def named_parameters(self):
+    def state_dict(self):
         uniq_set = set()
         ps = {}
         stack = []
@@ -827,8 +827,9 @@ class Module:
         self.dfs([], None, callback, callback_leave)
         return ps
 
-    def state_dict(self):
-        return self.named_parameters()
+    def named_parameters(self):
+        state_dict = self.state_dict()
+        return list(state_dict.items())
 
     def load_state_dict(self, params):
         self.load_parameters(params)
@@ -1038,7 +1039,7 @@ Arguments of hook are defined as::
         >>> net.save('net.pkl')
         >>> net.load('net.pkl')
         '''
-        params = self.named_parameters()
+        params = self.state_dict()
         params_dict = {}
         for k, v in params.items():
             if isinstance(v, Var):
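
Taken together, these hunks rename the dict-returning method to state_dict() and reimplement named_parameters() on top of it as a list of (name, parameter) pairs, with save() updated to call the new name. A minimal usage sketch, assuming only that jittor >= 1.2.3.105 is installed:

    import jittor as jt

    net = jt.nn.Conv(3, 3, 3)
    sd = net.state_dict()            # dict: {'weight': Var, 'bias': Var}
    pairs = net.named_parameters()   # list of (name, Var) pairs after this commit
    assert dict(pairs).keys() == sd.keys()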


@@ -47,6 +47,7 @@ void NcclBroadcastOp::jit_run() {
         @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
         @if(@strcmp(@Tx,float64)==0, ncclFloat64)
         @if(@strcmp(@Tx,int64)==0, ncclInt64)
+        @if(@strcmp(@Tx,uint8)==0, ncclUint8)
     )
     auto* __restrict__ xp = x->ptr<Tx>();
     auto* __restrict__ yp = y->ptr<Tx>();


@@ -60,6 +60,7 @@ void MpiBroadcastOp::jit_run() {
         @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
         @if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
         @if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
+        @if(@strcmp(@Tx,uint8)==0, MPI_CHAR)
     )
     auto* __restrict__ yp = y->ptr<Tx>();
     MPI_Bcast(yp, y->num, T_MPI, root, MPI_COMM_WORLD);
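
With both mappings in place, uint8 now resolves to ncclUint8 on the NCCL path and MPI_CHAR on the MPI path, so byte tensors (e.g. raw image data) can be broadcast without an explicit cast. A sketch of one way to exercise it; mpi_ops.mpi_broadcast and mpi.world_rank() are the wrapper names used by jittor's own MPI tests, and the script is assumed to run under mpirun with a jittor build that has MPI enabled:

    # hypothetical launch: mpirun -np 2 python3 bcast_u8.py
    import numpy as np
    import jittor as jt
    from jittor.compile_extern import mpi, mpi_ops

    x = jt.array(np.arange(16, dtype='uint8').reshape(4, 4))
    y = mpi_ops.mpi_broadcast(x, 0)   # uint8 is accepted after this commit
    print(mpi.world_rank(), y.dtype)  # every rank now holds root's uint8 data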


@@ -157,7 +157,7 @@ struct NanoString {
 
 // force_type = 1 for int, 2 for float
 inline
-NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
+NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0, NanoString op=ns_void) {
     bool is_float = v1.is_float() || v2.is_float();
     int dsize = std::max(v1.dsize(), v2.dsize());
     if (force_type == 1)
@@ -171,6 +171,8 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
         if (dsize==8) return ns_int64;
         if (dsize==4) return ns_int32;
         if (dsize==2) return ns_int16;
+        if (op.data == ns_add.data || op.data == ns_subtract.data)
+            return ns_int8;
         return v1;
     }
 }
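
The new branch only fires for one-byte integer operands: add and subtract now infer a signed int8 result (so a difference of two uint8 values can represent negative numbers), while every other op still falls through to the first operand's type. A quick check from Python, again assuming jittor >= 1.2.3.105:

    import numpy as np
    import jittor as jt

    a = jt.array(np.array([10, 20], dtype='uint8'))
    b = jt.array(np.array([30, 5], dtype='uint8'))
    print((a - b).dtype)  # int8: new rule, result may be negative
    print((a + b).dtype)  # int8: same rule for add
    print((a * b).dtype)  # uint8: falls through to "return v1"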


@@ -424,7 +424,7 @@ NanoString binary_dtype_infer(NanoString op, Var* x, Var* y) {
     int force_type=0;
     if (op == ns_divide) force_type=2; // force float
     if (op == ns_floor_divide) force_type=1; // force int
-    return op.is_bool() ? ns_bool : dtype_infer(x->ns, y->ns, force_type);
+    return op.is_bool() ? ns_bool : dtype_infer(x->ns, y->ns, force_type, op);
 }
 
 BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
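
Passing op through is what lets the int8 rule above see which operator is running; the pre-existing force_type special cases are untouched and can be observed directly (same assumptions as the sketches above):

    import numpy as np
    import jittor as jt

    a = jt.array(np.array([3, 4], dtype='int32'))
    b = jt.array(np.array([2, 2], dtype='int32'))
    print((a / b).dtype)   # float32: ns_divide forces a float result
    print((a // b).dtype)  # int32: ns_floor_divide forces an int result
    print((a < b).dtype)   # bool: boolean ops return ns_bool directly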


@@ -93,8 +93,8 @@ class TestCore(unittest.TestCase):
             def __init__(self):
                 self.conv1 = jt.nn.Conv(3,3,3)
         net = Net()
-        assert list(net.named_parameters().keys()) == ['conv1.weight', 'conv1.bias']
-        assert list(net.conv1.named_parameters().keys()) == ['weight', 'bias']
+        assert list(net.state_dict().keys()) == ['conv1.weight', 'conv1.bias']
+        assert list(net.conv1.state_dict().keys()) == ['weight', 'bias']
         pkl_name = os.path.join(jt.flags.cache_path, "sub.pkl")
         net.conv1.save(pkl_name)
         net.conv1.load(pkl_name)

Binary file not shown.