diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index c70ed02a..9e958d1b 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -879,14 +879,19 @@ class Module: else: assert isinstance(v, Var), \ f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}" - LOG.v(f'load parameter {key} success ...') if isinstance(params[key], np.ndarray) or isinstance(params[key], list): - v.update(array(params[key])) + param = array(params[key]) elif isinstance(params[key], Var): - v.update(params[key]) + param = params[key] else: # assume is pytorch tensor - v.update(array(params[key].cpu().detach().numpy())) + param = array(params[key].cpu().detach().numpy()) + if param.shape == v.shape: + LOG.v(f'load parameter {key} success ...') + v.update(param) + else: + n_failed += 1 + LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}') if n_failed: LOG.w(f"load total {len(params)} params, {n_failed} failed")