mirror of https://github.com/Jittor/Jittor
update auto_diff
This commit is contained in:
parent
8f425206b9
commit
03e8253a36
|
@ -39,8 +39,9 @@ class Optimizer(object):
|
|||
|
||||
@property
|
||||
def defaults(self):
|
||||
exclude = set(("defaults", "param_groups", "n_step"))
|
||||
return { k:v for k, v in self.__dict__.items() if k[0] != '_' and k not in exclude }
|
||||
exclude = set(("defaults", "param_groups", "n_step", "pre_step", "step"))
|
||||
return { k:v for k, v in self.__dict__.items()
|
||||
if k[0] != '_' and k not in exclude and not callable(v) }
|
||||
|
||||
def pre_step(self, loss):
|
||||
""" something should be done before step, such as calc gradients, mpi sync, and so on.
|
||||
|
|
|
@ -140,7 +140,7 @@ class Hook:
|
|||
elif isinstance(pre_data, np.ndarray):
|
||||
if pre_data.shape != data.shape:
|
||||
has_error += 1
|
||||
LOG.e(f"Ndarray shape <{name}> not match")
|
||||
LOG.e(f"Ndarray shape <{name}> not match {pre_data.shape} != {data.shape}")
|
||||
return
|
||||
self.check_array(name, pre_data, data)
|
||||
elif isinstance(pre_data, dict):
|
||||
|
|
Loading…
Reference in New Issue