This commit is contained in:
Dun Liang 2021-07-19 11:22:00 +08:00
commit 62bbdcd7d9
2 changed files with 74 additions and 22 deletions

View File

@ -41,6 +41,8 @@ net(data)
'''
with open("/tmp/test_pt_hook.py", 'w') as f:
f.write(code)
print(jt.flags.cache_path)
os.system(f"rm -rf {jt.flags.cache_path}/../../auto_diff/resnet50")
assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
code = '''

View File

@ -1,5 +1,6 @@
import os
from pathlib import Path
from collections import defaultdict
import pickle
import numpy as np
import jittor_utils
@ -52,7 +53,12 @@ def hook_pt_randn(*shape, device=None):
def hook_pt_normal(mean, std):
import torch
shape = tuple(mean.shape)
if hasattr(mean, 'shape'):
shape = tuple(mean.shape)
elif hasattr(std, 'shape'):
shape = tuple(std.shape)
else:
shape = (1,)
np.random.seed(0)
return torch.from_numpy(np.random.normal(size=shape).astype("float32")).to(std.device) * std + mean
@ -91,9 +97,16 @@ class Hook:
self.rid = 0
self.base_name = base_name
self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
os.makedirs(self.base_path, exist_ok=True)
if not os.path.exists(self.base_path):
os.makedirs(self.base_path, exist_ok=True)
self.mode = 'save'
else:
self.mode = 'check'
self.record_status = defaultdict(int)
self.rtol = rtol
self.atol = atol
LOG.i(f"Jittor AutoDiff: [{self.mode}] mode")
LOG.i("Use cache path:", self.base_path)
LOG.i(f"rtol:{rtol} atol:{atol}")
@ -139,6 +152,10 @@ class Hook:
b = data[i] if i<len(data) else "None"
self.check(name+f".{i}", a, b)
elif isinstance(pre_data, np.ndarray):
if len(pre_data.shape) == 0:
pre_data = np.array([pre_data])
if len(data.shape) == 0:
data = np.array([data])
if pre_data.shape != data.shape:
has_error += 1
LOG.e(f"Ndarray shape <{name}> not match {pre_data.shape} != {data.shape}")
@ -167,31 +184,30 @@ class Hook:
def record(self, name, data, ex_name=""):
if os.environ.get("use_auto_diff", '1') == '0':
return
rid = self.rid
self.rid += 1
fpath = os.path.join(self.base_path, f"{rid}.pkl")
self.record_status[name] += 1
fpath = os.path.join(self.base_path, f"{name}-{self.record_status[name]}.pkl")
data = convert(data)
if os.path.isfile(fpath):
with open(fpath, 'rb') as f:
pre_name, pre_data = pickle.load(f)
if pre_name != name:
self.rid += 1
if self.mode == 'check':
if os.path.isfile(fpath):
with open(fpath, 'rb') as f:
pre_name, pre_data = pickle.load(f)
LOG.i(f"check {self.rid}:<{ex_name}{name}> ...")
self.check(ex_name+name, pre_data, data)
else:
global has_error
has_error += 1
LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
self.rid -= 1
LOG.e(f"No previous result found: {name}")
return
LOG.i(f"check {rid}:<{ex_name}{name}> ...")
self.check(ex_name+name, pre_data, data)
else:
with open(fpath, 'wb') as f:
pickle.dump((name, data), f)
LOG.i(f"save {rid}:<{name}> ok")
LOG.i(f"save {self.rid}:<{name}> ok")
def record_params(self, parameters_dict):
def record_params(self, parameters_dict, mod_name=""):
if os.environ.get("use_auto_diff", '1') == '0':
return
rid = self.rid
self.rid += 1
global has_error
pps = {}
for k, v in parameters_dict.items():
@ -199,8 +215,12 @@ class Hook:
continue
pps[k] = v
ps = { name:convert(param) for name, param in pps.items() }
fpath = os.path.join(self.base_path, f"{rid}-params.pkl")
if os.path.isfile(fpath):
rec_name = f"{mod_name}_params"
rec_name = f"{rec_name}-{self.record_status[rec_name]}"
self.record_status[rec_name] += 1
fpath = os.path.join(self.base_path, rec_name+".pkl")
if self.mode == 'check':
with open(fpath, 'rb') as f:
prev_ps = pickle.load(f)
if len(prev_ps) != len(ps):
@ -253,7 +273,6 @@ class Hook:
return ret
return new_func
def hook_module(self, mod, mod_name=""):
if os.environ.get("use_auto_diff", '1') == '0':
return
@ -288,7 +307,7 @@ class Hook:
self.record(name+'.p', module.p, "["+mod_class_name+"]")
module.eval()
ps = { mod_name+k:v for k, v in mod.state_dict().items() }
self.record_params(ps)
self.record_params(ps, mod_name)
self.record("module names", names)
def hook_optimizer(self, opt, opt_name=""):
@ -329,4 +348,35 @@ class Hook:
self.record(f"{opt_name}.grads[{gid}]", p.grad)
self.record(f"{opt_name}.params[{gid}]", p)
gid += 1
opt.step = step_hook
opt.step = step_hook
def save_input(self, *data):
'''
for input, label in torch_dataloader:
hook.save_input(data)
'''
if self.mode == "save":
self.record_status["[input]"] += 1
fpath = os.path.join(self.base_path, f"__input-{self.record_status['[input]']}.pkl")
with open(fpath, 'wb') as f:
pickle.dump(convert(data), f)
LOG.i(f"save input: ok")
else:
raise RuntimeError("save_input is invalid in [check] mode")
def load_input(self):
'''
for fake_input, fake_label in jittor_dataset:
input, label = hook.load_input()
input = jt.array(input)
label = jt.array(label)
'''
if self.mode == "check":
self.record_status["[input]"] += 1
fpath = os.path.join(self.base_path, f"__input-{self.record_status['[input]']}.pkl")
with open(fpath, 'rb') as f:
data = pickle.load(f)
LOG.i(f"load input: ok")
return data
else:
raise RuntimeError("load_input is invalid in [save] mode")