update auto_diff

This commit is contained in:
Dun Liang 2020-12-22 21:58:52 +08:00
parent 8f425206b9
commit 03e8253a36
2 changed files with 4 additions and 3 deletions

View File

@ -39,8 +39,9 @@ class Optimizer(object):
@property
def defaults(self):
exclude = set(("defaults", "param_groups", "n_step"))
return { k:v for k, v in self.__dict__.items() if k[0] != '_' and k not in exclude }
exclude = set(("defaults", "param_groups", "n_step", "pre_step", "step"))
return { k:v for k, v in self.__dict__.items()
if k[0] != '_' and k not in exclude and not callable(v) }
def pre_step(self, loss):
""" something should be done before step, such as calc gradients, mpi sync, and so on.

View File

@ -140,7 +140,7 @@ class Hook:
elif isinstance(pre_data, np.ndarray):
if pre_data.shape != data.shape:
has_error += 1
LOG.e(f"Ndarray shape <{name}> not match")
LOG.e(f"Ndarray shape <{name}> not match {pre_data.shape} != {data.shape}")
return
self.check_array(name, pre_data, data)
elif isinstance(pre_data, dict):