From 095acdb575893962c241908599e4b90488b67c2d Mon Sep 17 00:00:00 2001 From: lzhengning Date: Fri, 5 Feb 2021 14:09:19 +0800 Subject: [PATCH 1/2] fix: do not load parameters with inconsistent shapes --- python/jittor/__init__.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index ee2a6d10..3a79aaf1 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -724,14 +724,19 @@ class Module: else: assert isinstance(v, Var), \ f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}" - LOG.v(f'load parameter {key} success ...') if isinstance(params[key], np.ndarray) or isinstance(params[key], list): - v.update(array(params[key])) + param = array(params[key]) elif isinstance(params[key], Var): - v.update(params[key]) + param = params[key] else: # assume is pytorch tensor - v.update(array(params[key].cpu().detach().numpy())) + param = array(params[key].cpu().detach().numpy()) + if param.shape == v.shape: + LOG.v(f'load parameter {key} success ...') + v.update(param) + else: + n_failed += 1 + LOG.w(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}') if n_failed: LOG.w(f"load total {len(params)} params, {n_failed} failed") From 32549e4be847c79d0ae8cfee7714362184fac171 Mon Sep 17 00:00:00 2001 From: lzhengning Date: Fri, 5 Feb 2021 14:24:37 +0800 Subject: [PATCH 2/2] LOG: warning to error --- python/jittor/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 3a79aaf1..4ff7580e 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -736,7 +736,7 @@ class Module: v.update(param) else: n_failed += 1 - LOG.w(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}') + LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}') if n_failed: LOG.w(f"load total {len(params)} params, {n_failed} failed")