mirror of https://github.com/Jittor/Jittor
Merge branch 'auto_diff' of https://github.com/lzhengning/jittor
This commit is contained in:
commit
62bbdcd7d9
|
@ -41,6 +41,8 @@ net(data)
|
|||
'''
|
||||
with open("/tmp/test_pt_hook.py", 'w') as f:
|
||||
f.write(code)
|
||||
print(jt.flags.cache_path)
|
||||
os.system(f"rm -rf {jt.flags.cache_path}/../../auto_diff/resnet50")
|
||||
assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
|
||||
assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
|
||||
code = '''
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import pickle
|
||||
import numpy as np
|
||||
import jittor_utils
|
||||
|
@ -52,7 +53,12 @@ def hook_pt_randn(*shape, device=None):
|
|||
|
||||
def hook_pt_normal(mean, std):
|
||||
import torch
|
||||
shape = tuple(mean.shape)
|
||||
if hasattr(mean, 'shape'):
|
||||
shape = tuple(mean.shape)
|
||||
elif hasattr(std, 'shape'):
|
||||
shape = tuple(std.shape)
|
||||
else:
|
||||
shape = (1,)
|
||||
|
||||
np.random.seed(0)
|
||||
return torch.from_numpy(np.random.normal(size=shape).astype("float32")).to(std.device) * std + mean
|
||||
|
@ -91,9 +97,16 @@ class Hook:
|
|||
self.rid = 0
|
||||
self.base_name = base_name
|
||||
self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
|
||||
os.makedirs(self.base_path, exist_ok=True)
|
||||
if not os.path.exists(self.base_path):
|
||||
os.makedirs(self.base_path, exist_ok=True)
|
||||
self.mode = 'save'
|
||||
else:
|
||||
self.mode = 'check'
|
||||
|
||||
self.record_status = defaultdict(int)
|
||||
self.rtol = rtol
|
||||
self.atol = atol
|
||||
LOG.i(f"Jittor AutoDiff: [{self.mode}] mode")
|
||||
LOG.i("Use cache path:", self.base_path)
|
||||
LOG.i(f"rtol:{rtol} atol:{atol}")
|
||||
|
||||
|
@ -139,6 +152,10 @@ class Hook:
|
|||
b = data[i] if i<len(data) else "None"
|
||||
self.check(name+f".{i}", a, b)
|
||||
elif isinstance(pre_data, np.ndarray):
|
||||
if len(pre_data.shape) == 0:
|
||||
pre_data = np.array([pre_data])
|
||||
if len(data.shape) == 0:
|
||||
data = np.array([data])
|
||||
if pre_data.shape != data.shape:
|
||||
has_error += 1
|
||||
LOG.e(f"Ndarray shape <{name}> not match {pre_data.shape} != {data.shape}")
|
||||
|
@ -167,31 +184,30 @@ class Hook:
|
|||
def record(self, name, data, ex_name=""):
|
||||
if os.environ.get("use_auto_diff", '1') == '0':
|
||||
return
|
||||
rid = self.rid
|
||||
self.rid += 1
|
||||
fpath = os.path.join(self.base_path, f"{rid}.pkl")
|
||||
self.record_status[name] += 1
|
||||
fpath = os.path.join(self.base_path, f"{name}-{self.record_status[name]}.pkl")
|
||||
data = convert(data)
|
||||
if os.path.isfile(fpath):
|
||||
with open(fpath, 'rb') as f:
|
||||
pre_name, pre_data = pickle.load(f)
|
||||
if pre_name != name:
|
||||
self.rid += 1
|
||||
|
||||
if self.mode == 'check':
|
||||
if os.path.isfile(fpath):
|
||||
with open(fpath, 'rb') as f:
|
||||
pre_name, pre_data = pickle.load(f)
|
||||
LOG.i(f"check {self.rid}:<{ex_name}{name}> ...")
|
||||
self.check(ex_name+name, pre_data, data)
|
||||
else:
|
||||
global has_error
|
||||
has_error += 1
|
||||
LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
|
||||
self.rid -= 1
|
||||
LOG.e(f"No previous result found: {name}")
|
||||
return
|
||||
LOG.i(f"check {rid}:<{ex_name}{name}> ...")
|
||||
self.check(ex_name+name, pre_data, data)
|
||||
else:
|
||||
with open(fpath, 'wb') as f:
|
||||
pickle.dump((name, data), f)
|
||||
LOG.i(f"save {rid}:<{name}> ok")
|
||||
LOG.i(f"save {self.rid}:<{name}> ok")
|
||||
|
||||
def record_params(self, parameters_dict):
|
||||
def record_params(self, parameters_dict, mod_name=""):
|
||||
if os.environ.get("use_auto_diff", '1') == '0':
|
||||
return
|
||||
rid = self.rid
|
||||
self.rid += 1
|
||||
global has_error
|
||||
pps = {}
|
||||
for k, v in parameters_dict.items():
|
||||
|
@ -199,8 +215,12 @@ class Hook:
|
|||
continue
|
||||
pps[k] = v
|
||||
ps = { name:convert(param) for name, param in pps.items() }
|
||||
fpath = os.path.join(self.base_path, f"{rid}-params.pkl")
|
||||
if os.path.isfile(fpath):
|
||||
rec_name = f"{mod_name}_params"
|
||||
rec_name = f"{rec_name}-{self.record_status[rec_name]}"
|
||||
self.record_status[rec_name] += 1
|
||||
fpath = os.path.join(self.base_path, rec_name+".pkl")
|
||||
|
||||
if self.mode == 'check':
|
||||
with open(fpath, 'rb') as f:
|
||||
prev_ps = pickle.load(f)
|
||||
if len(prev_ps) != len(ps):
|
||||
|
@ -253,7 +273,6 @@ class Hook:
|
|||
return ret
|
||||
return new_func
|
||||
|
||||
|
||||
def hook_module(self, mod, mod_name=""):
|
||||
if os.environ.get("use_auto_diff", '1') == '0':
|
||||
return
|
||||
|
@ -288,7 +307,7 @@ class Hook:
|
|||
self.record(name+'.p', module.p, "["+mod_class_name+"]")
|
||||
module.eval()
|
||||
ps = { mod_name+k:v for k, v in mod.state_dict().items() }
|
||||
self.record_params(ps)
|
||||
self.record_params(ps, mod_name)
|
||||
self.record("module names", names)
|
||||
|
||||
def hook_optimizer(self, opt, opt_name=""):
|
||||
|
@ -329,4 +348,35 @@ class Hook:
|
|||
self.record(f"{opt_name}.grads[{gid}]", p.grad)
|
||||
self.record(f"{opt_name}.params[{gid}]", p)
|
||||
gid += 1
|
||||
opt.step = step_hook
|
||||
opt.step = step_hook
|
||||
|
||||
def save_input(self, *data):
|
||||
'''
|
||||
for input, label in torch_dataloader:
|
||||
hook.save_input(data)
|
||||
'''
|
||||
if self.mode == "save":
|
||||
self.record_status["[input]"] += 1
|
||||
fpath = os.path.join(self.base_path, f"__input-{self.record_status['[input]']}.pkl")
|
||||
with open(fpath, 'wb') as f:
|
||||
pickle.dump(convert(data), f)
|
||||
LOG.i(f"save input: ok")
|
||||
else:
|
||||
raise RuntimeError("save_input is invalid in [check] mode")
|
||||
|
||||
def load_input(self):
|
||||
'''
|
||||
for fake_input, fake_label in jittor_dataset:
|
||||
input, label = hook.load_input()
|
||||
input = jt.array(input)
|
||||
label = jt.array(label)
|
||||
'''
|
||||
if self.mode == "check":
|
||||
self.record_status["[input]"] += 1
|
||||
fpath = os.path.join(self.base_path, f"__input-{self.record_status['[input]']}.pkl")
|
||||
with open(fpath, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
LOG.i(f"load input: ok")
|
||||
return data
|
||||
else:
|
||||
raise RuntimeError("load_input is invalid in [save] mode")
|
||||
|
|
Loading…
Reference in New Issue