This commit is contained in:
Dun Liang 2021-02-07 22:37:56 +08:00
commit 80135d4ef5
1 changed files with 9 additions and 4 deletions

View File

@ -724,14 +724,19 @@ class Module:
else:
assert isinstance(v, Var), \
f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
LOG.v(f'load parameter {key} success ...')
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
v.update(array(params[key]))
param = array(params[key])
elif isinstance(params[key], Var):
v.update(params[key])
param = params[key]
else:
# assume is pytorch tensor
v.update(array(params[key].cpu().detach().numpy()))
param = array(params[key].cpu().detach().numpy())
if param.shape == v.shape:
LOG.v(f'load parameter {key} success ...')
v.update(param)
else:
n_failed += 1
LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}')
if n_failed:
LOG.w(f"load total {len(params)} params, {n_failed} failed")