mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of github.com:Jittor/jittor into win_cuda
This commit is contained in:
commit
9293045993
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.103'
|
||||
__version__ = '1.2.3.105'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -804,7 +804,7 @@ class Module:
|
|||
self.dfs([], None, callback, callback_leave)
|
||||
return _uniq(ps)
|
||||
|
||||
def named_parameters(self):
|
||||
def state_dict(self):
|
||||
uniq_set = set()
|
||||
ps = {}
|
||||
stack = []
|
||||
|
@ -827,8 +827,9 @@ class Module:
|
|||
self.dfs([], None, callback, callback_leave)
|
||||
return ps
|
||||
|
||||
def state_dict(self):
|
||||
return self.named_parameters()
|
||||
def named_parameters(self):
|
||||
state_dict = self.state_dict()
|
||||
return list(state_dict.items())
|
||||
|
||||
def load_state_dict(self, params):
|
||||
self.load_parameters(params)
|
||||
|
@ -1038,7 +1039,7 @@ Arguments of hook are defined as::
|
|||
>>> net.save('net.pkl')
|
||||
>>> net.load('net.pkl')
|
||||
'''
|
||||
params = self.named_parameters()
|
||||
params = self.state_dict()
|
||||
params_dict = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, Var):
|
||||
|
|
|
@ -47,6 +47,7 @@ void NcclBroadcastOp::jit_run() {
|
|||
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
|
||||
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
|
||||
@if(@strcmp(@Tx,int64)==0, ncclInt64)
|
||||
@if(@strcmp(@Tx,uint8)==0, ncclUint8)
|
||||
)
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
|
|
|
@ -60,6 +60,7 @@ void MpiBroadcastOp::jit_run() {
|
|||
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
|
||||
@if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
|
||||
@if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
|
||||
@if(@strcmp(@Tx,uint8)==0, MPI_CHAR)
|
||||
)
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
MPI_Bcast(yp, y->num, T_MPI, root, MPI_COMM_WORLD);
|
||||
|
|
|
@ -157,7 +157,7 @@ struct NanoString {
|
|||
|
||||
// force_type = 1 for int, 2 for float
|
||||
inline
|
||||
NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
|
||||
NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0, NanoString op=ns_void) {
|
||||
bool is_float = v1.is_float() || v2.is_float();
|
||||
int dsize = std::max(v1.dsize(), v2.dsize());
|
||||
if (force_type == 1)
|
||||
|
@ -171,6 +171,8 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
|
|||
if (dsize==8) return ns_int64;
|
||||
if (dsize==4) return ns_int32;
|
||||
if (dsize==2) return ns_int16;
|
||||
if (op.data == ns_add.data || op.data == ns_subtract.data)
|
||||
return ns_int8;
|
||||
return v1;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -424,7 +424,7 @@ NanoString binary_dtype_infer(NanoString op, Var* x, Var* y) {
|
|||
int force_type=0;
|
||||
if (op == ns_divide) force_type=2; // force float
|
||||
if (op == ns_floor_divide) force_type=1; // force int
|
||||
return op.is_bool() ? ns_bool : dtype_infer(x->ns, y->ns, force_type);
|
||||
return op.is_bool() ? ns_bool : dtype_infer(x->ns, y->ns, force_type, op);
|
||||
}
|
||||
|
||||
BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
|
||||
|
|
|
@ -93,8 +93,8 @@ class TestCore(unittest.TestCase):
|
|||
def __init__(self):
|
||||
self.conv1 = jt.nn.Conv(3,3,3)
|
||||
net = Net()
|
||||
assert list(net.named_parameters().keys()) == ['conv1.weight', 'conv1.bias']
|
||||
assert list(net.conv1.named_parameters().keys()) == ['weight', 'bias']
|
||||
assert list(net.state_dict().keys()) == ['conv1.weight', 'conv1.bias']
|
||||
assert list(net.conv1.state_dict().keys()) == ['weight', 'bias']
|
||||
pkl_name = os.path.join(jt.flags.cache_path, "sub.pkl")
|
||||
net.conv1.save(pkl_name)
|
||||
net.conv1.load(pkl_name)
|
||||
|
|
Binary file not shown.
Loading…
Reference in New Issue