Merge branch 'master' into doc

This commit is contained in:
lzhengning 2021-02-09 18:48:36 +08:00
commit b43abebdbe
1 changed files with 9 additions and 4 deletions

View File

@ -879,14 +879,19 @@ class Module:
else: else:
assert isinstance(v, Var), \ assert isinstance(v, Var), \
f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}" f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
LOG.v(f'load parameter {key} success ...')
if isinstance(params[key], np.ndarray) or isinstance(params[key], list): if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
v.update(array(params[key])) param = array(params[key])
elif isinstance(params[key], Var): elif isinstance(params[key], Var):
v.update(params[key]) param = params[key]
else: else:
# assume is pytorch tensor # assume is pytorch tensor
v.update(array(params[key].cpu().detach().numpy())) param = array(params[key].cpu().detach().numpy())
if param.shape == v.shape:
LOG.v(f'load parameter {key} success ...')
v.update(param)
else:
n_failed += 1
LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}')
if n_failed: if n_failed:
LOG.w(f"load total {len(params)} params, {n_failed} failed") LOG.w(f"load total {len(params)} params, {n_failed} failed")