mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor
This commit is contained in:
commit
80135d4ef5
|
@ -724,14 +724,19 @@ class Module:
|
|||
else:
|
||||
assert isinstance(v, Var), \
|
||||
f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
|
||||
LOG.v(f'load parameter {key} success ...')
|
||||
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
|
||||
v.update(array(params[key]))
|
||||
param = array(params[key])
|
||||
elif isinstance(params[key], Var):
|
||||
v.update(params[key])
|
||||
param = params[key]
|
||||
else:
|
||||
# assume is pytorch tensor
|
||||
v.update(array(params[key].cpu().detach().numpy()))
|
||||
param = array(params[key].cpu().detach().numpy())
|
||||
if param.shape == v.shape:
|
||||
LOG.v(f'load parameter {key} success ...')
|
||||
v.update(param)
|
||||
else:
|
||||
n_failed += 1
|
||||
LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}')
|
||||
if n_failed:
|
||||
LOG.w(f"load total {len(params)} params, {n_failed} failed")
|
||||
|
||||
|
|
Loading…
Reference in New Issue