mirror of https://github.com/Jittor/Jittor
Merge branch 'master' into doc
This commit is contained in:
commit
b43abebdbe
|
@ -879,14 +879,19 @@ class Module:
|
||||||
else:
|
else:
|
||||||
assert isinstance(v, Var), \
|
assert isinstance(v, Var), \
|
||||||
f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
|
f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
|
||||||
LOG.v(f'load parameter {key} success ...')
|
|
||||||
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
|
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
|
||||||
v.update(array(params[key]))
|
param = array(params[key])
|
||||||
elif isinstance(params[key], Var):
|
elif isinstance(params[key], Var):
|
||||||
v.update(params[key])
|
param = params[key]
|
||||||
else:
|
else:
|
||||||
# assume is pytorch tensor
|
# assume is pytorch tensor
|
||||||
v.update(array(params[key].cpu().detach().numpy()))
|
param = array(params[key].cpu().detach().numpy())
|
||||||
|
if param.shape == v.shape:
|
||||||
|
LOG.v(f'load parameter {key} success ...')
|
||||||
|
v.update(param)
|
||||||
|
else:
|
||||||
|
n_failed += 1
|
||||||
|
LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}')
|
||||||
if n_failed:
|
if n_failed:
|
||||||
LOG.w(f"load total {len(params)} params, {n_failed} failed")
|
LOG.w(f"load total {len(params)} params, {n_failed} failed")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue