mirror of https://github.com/Jittor/Jittor
add AutoDiff
This commit is contained in:
parent ee4dbe4c11
commit e93810879c
jittor/__init__.py
@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.1.7.16'
+__version__ = '1.1.7.18'
 from . import lock
 with lock.lock_scope():
     from . import compiler
@@ -479,6 +479,17 @@ class Module:
         self.dfs([], None, callback, callback_leave)
         return _uniq(ps)
 
+    def named_parameters(self):
+        ps = self.parameters()
+        return [ (p.name(), p) for p in ps ]
+
+    def state_dict(self):
+        ps = self.parameters()
+        return { p.name(): p for p in ps }
+
+    def load_state_dict(self, params):
+        self.load_parameters(params)
+
     def modules(self):
         ms = []
         def callback(parents, k, v, n):
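These three methods mirror the PyTorch Module API on top of the existing parameters() traversal: state_dict() keys parameters by p.name(), and load_state_dict() is a thin alias for the existing load_parameters(). A minimal usage sketch (the nn.Linear layer below is an illustrative assumption, not part of the commit):

    import jittor as jt
    from jittor import nn

    net = nn.Linear(4, 2)
    sd = net.state_dict()               # maps p.name() -> Var
    for name, p in net.named_parameters():
        print(name, p.shape)
    net.load_state_dict(sd)             # restores parameters by name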
@@ -487,6 +498,36 @@ class Module:
         self.dfs([], None, callback, None)
         return _uniq(ms)
 
+    def named_modules(self):
+        ms = []
+        stack = []
+        def callback(parents, k, v, n):
+            stack.append(str(k))
+            name = ".".join(stack[1:])
+            ms.append((name, v))
+        def callback_leave(parents, k, v, n):
+            stack.pop()
+        self.dfs([], "", callback, callback_leave)
+        return ms
+
+    def register_forward_hook(self, func):
+        cls = self.__class__
+        self.__fhook__ = func
+        if hasattr(cls, "__hooked__"):
+            return
+        cls.__hooked__ = True
+        origin_call = cls.__call__
+        def new_call(self, *args, **kw):
+            ret = origin_call(self, *args, **kw)
+            if hasattr(self, "__fhook__"):
+                if len(kw):
+                    self.__fhook__(self, args, ret, kw)
+                else:
+                    self.__fhook__(self, args, ret)
+            return ret
+        self.__class__.__call__ = new_call
+
+
     def children(self):
         cd = []
         def callback(parents, k, v, n):
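named_modules() yields (dotted_name, module) pairs, with the root module under the empty name, and register_forward_hook() patches the class's __call__ exactly once (guarded by __hooked__) while storing the hook per instance in __fhook__. A minimal sketch of the hook signature, again assuming an illustrative nn.Linear:

    import jittor as jt
    from jittor import nn

    def print_shapes(mod, args, output):
        # called after the module runs; args is the positional-input tuple
        print(mod.__class__.__name__, [a.shape for a in args], output.shape)

    layer = nn.Linear(4, 2)
    layer.register_forward_hook(print_shapes)
    y = layer(jt.random((3, 4)))        # triggers print_shapes

Because new_call is installed on the class, every instance of a hooked class pays the hasattr check, but only instances carrying __fhook__ invoke a hook; calls made with keyword arguments pass them to the hook as a fourth parameter.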
@@ -720,6 +761,7 @@ def make_module(func, exec_n_args=1):
             return f"{func.__name__}({self.extra_repr()})"
         def extra_repr(self):
             return ",".join(map(str, self.args))
+    MakeModule.__name__ = func.__name__
     return MakeModule
 
 
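This one-line addition makes classes generated by make_module report the wrapped function's name (e.g. relu rather than MakeModule), which the class-name check in auto_diff's forward_hook below appears to rely on.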
(new file: AutoDiff unit test)
@@ -0,0 +1,68 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. All Rights Reserved.
+# Authors: Dun Liang <randonlang@gmail.com>.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import numpy as np
+import os
+import sys
+import jittor as jt
+
+skip_this_test = False
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+    import torchvision.models as tcmodels
+    from torch import nn
+except:
+    torch = None
+    skip_this_test = True
+
+@unittest.skipIf(skip_this_test, "skip_this_test")
+class TestAutoDiff(unittest.TestCase):
+    def test_pt_hook(self):
+        code = '''
+import numpy as np
+from jittor_utils import auto_diff
+import torch
+import torchvision.models as tcmodels
+net = tcmodels.resnet50()
+net.train()
+hook = auto_diff.Hook("resnet50")
+hook.hook_module(net)
+
+np.random.seed(0)
+data = np.random.random((2,3,224,224)).astype('float32')
+data = torch.Tensor(data)
+net(data)
+# assert auto_diff.has_error == 0, auto_diff.has_error
+'''
+        with open("/tmp/test_pt_hook.py", 'w') as f:
+            f.write(code)
+        assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
+        assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
+        code = '''
+import numpy as np
+import jittor as jt
+from jittor_utils import auto_diff
+from jittor.models import resnet50
+net = resnet50()
+net.train()
+hook = auto_diff.Hook("resnet50")
+hook.hook_module(net)
+
+np.random.seed(0)
+data = np.random.random((2,3,224,224)).astype('float32')
+data = jt.array(data)
+net(data)
+# assert auto_diff.has_error == 0, auto_diff.has_error
+'''
+        with open("/tmp/test_jt_hook.py", 'w') as f:
+            f.write(code)
+        assert os.system(sys.executable+" /tmp/test_jt_hook.py") == 0
+
+
+if __name__ == "__main__":
+    unittest.main()
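All three subprocess runs construct Hook("resnet50") and therefore share one recording under ~/.cache/jittor/auto_diff/resnet50: the first PyTorch run finds no cached files and records, the second PyTorch run replays and must match itself, and the Jittor run is then checked against the recorded PyTorch activations and parameters. The has_error assertions are left commented out, so the test currently only verifies that each script exits cleanly.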
(new file: the jittor_utils.auto_diff module)
@@ -0,0 +1,182 @@
+import os
+from pathlib import Path
+import pickle
+import numpy as np
+import jittor_utils
+from jittor_utils import LOG
+
+jittor_utils.try_import_jit_utils_core()
+
+
+has_error = 0
+
+def convert(data):
+    if isinstance(data, tuple):
+        return tuple( convert(v) for v in data )
+    if isinstance(data, list):
+        return [ convert(v) for v in data ]
+    if isinstance(data, np.ndarray):
+        return data
+    if isinstance(data, dict):
+        return {k:convert(data[k]) for k in data}
+    if hasattr(data, "numpy"):
+        return data.detach().numpy()
+    return data
+
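convert() normalizes arbitrarily nested containers into pickle-friendly data: anything exposing a numpy() method (a torch.Tensor, or a framework tensor with the same interface) is detached and converted, while other leaves pass through unchanged. A quick check, assuming torch is installed:

    import numpy as np
    import torch
    from jittor_utils.auto_diff import convert

    out = convert({"w": torch.ones(2), "meta": [1, (torch.zeros(1),)]})
    assert isinstance(out["w"], np.ndarray)
    assert isinstance(out["meta"][1][0], np.ndarray)
    assert out["meta"][0] == 1   # non-tensor leaves are returned unchanged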
+class Hook:
+    def __init__(self, base_name, rtol=5e-2, atol=1e-3):
+        self.rid = 0
+        self.base_name = base_name
+        self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
+        os.makedirs(self.base_path, exist_ok=True)
+        self.rtol = rtol
+        self.atol = atol
+        LOG.i("Use cache path:", self.base_path)
+        LOG.i(f"rtol:{rtol} atol:{atol}")
+
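Hooks constructed with the same base_name share one recording directory, and nothing in the cache records which framework wrote it; that is exactly what lets a PyTorch run record and a Jittor run replay. Delete ~/.cache/jittor/auto_diff/<base_name> to start a fresh recording. For example:

    hook = Hook("resnet50")                        # ~/.cache/jittor/auto_diff/resnet50
    hook = Hook("resnet50", rtol=1e-1, atol=1e-2)  # looser elementwise tolerances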
+    def check_array(self, name, a, b):
+        rtol = self.rtol
+        atol = self.atol
+        global has_error
+        err = np.abs(a-b)
+        tol = atol + rtol * np.abs(b)
+        is_error = np.logical_or( err > tol, (a>=-1e-5)!=(b>=-1e-5))
+        index = np.where(is_error)
+        assert len(index)>0
+        if len(index[0]) == 0:
+            return
+
+        has_error += 1
+        LOG.e(f"Ndarray <{name}> not match, shape:{a.shape}")
+        i = tuple( i[0] for i in index )
+        err_rate = is_error.mean()
+        LOG.e(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}%")
+        if err_rate > 0.01:
+            LOG.e("!"*10+"Very HIGH err rate"+"!"*10)
+
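The elementwise test has the same shape as numpy.isclose: an entry fails when |a - b| > atol + rtol*|b|, and additionally when a and b fall on opposite sides of (roughly) zero, which catches sign flips that a purely relative tolerance would forgive at small magnitudes. A worked example with the default rtol=5e-2, atol=1e-3:

    import numpy as np
    a = np.array([2.05, 0.02])
    b = np.array([2.00, -0.02])
    tol = 1e-3 + 5e-2 * np.abs(b)        # [0.101, 0.002]
    print(np.abs(a - b) > tol)           # [False, True]
    print((a >= -1e-5) != (b >= -1e-5))  # [False, True], the sign test also trips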
+    def check(self, name, pre_data, data):
+        global has_error
+        assert type(pre_data) == type(data)
+        if isinstance(pre_data, (list, tuple)):
+            if len(pre_data) != len(data):
+                has_error += 1
+                LOG.e(f"Name <{name}> len not match, {len(pre_data)} != {len(data)}")
+            n = max(len(pre_data), len(data))
+            for i in range(n):
+                a = pre_data[i] if i<len(pre_data) else "None"
+                b = data[i] if i<len(data) else "None"
+                self.check(name+f".{i}", a, b)
+        elif isinstance(pre_data, np.ndarray):
+            if pre_data.shape != data.shape:
+                has_error += 1
+                LOG.e(f"Ndarray shape <{name}> not match")
+                return
+            self.check_array(name, pre_data, data)
+        elif isinstance(pre_data, dict):
+            if len(pre_data) != len(data):
+                has_error += 1
+                LOG.e(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
+            for k in pre_data:
+                if k not in data:
+                    has_error += 1
+                    LOG.e(f"Key <{k}> not in data, Name <{name}>")
+                    continue
+                self.check(name+f".{k}", pre_data[k], data[k])
+        else:
+            if pre_data != data:
+                has_error += 1
+                LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}")
+
+    def record(self, name, data):
+        rid = self.rid
+        self.rid += 1
+        fpath = os.path.join(self.base_path, f"{rid}.pkl")
+        data = convert(data)
+        if os.path.isfile(fpath):
+            with open(fpath, 'rb') as f:
+                pre_name, pre_data = pickle.load(f)
+            if pre_name != name:
+                global has_error
+                has_error += 1
+                LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
+                return
+            LOG.i(f"check {rid}:<{name}> ...")
+            self.check(name, pre_data, data)
+        else:
+            with open(fpath, 'wb') as f:
+                pickle.dump((name, data), f)
+            LOG.i(f"save {rid}:<{name}> ok")
+
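record() is purely positional: the rid-th call of a replay run is compared against the rid-th call of the recording run, and only the attached name guards against the two runs drifting out of step (hence the name-mismatch error above). A sketch of the two-phase behavior, using an assumed base_name "demo":

    import numpy as np
    from jittor_utils.auto_diff import Hook

    hook = Hook("demo")
    hook.record("step0", np.ones(3))   # first run: writes 0.pkl
    # ... in a later run with the same base_name:
    hook = Hook("demo")
    hook.record("step0", np.ones(3))   # finds 0.pkl and checks instead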
+    def record_params(self, parameters_dict):
+        global has_error
+        pps = {}
+        for k, v in parameters_dict.items():
+            if k.endswith("num_batches_tracked"):
+                continue
+            pps[k] = v
+        ps = { name:convert(param) for name, param in pps.items() }
+        fpath = os.path.join(self.base_path, f"params.pkl")
+        if os.path.isfile(fpath):
+            with open(fpath, 'rb') as f:
+                prev_ps = pickle.load(f)
+            if len(prev_ps) != len(ps):
+                has_error += 1
+                LOG.e(f"Params len not match {len(prev_ps)} != {len(ps)}")
+            for k in ps:
+                a = ps[k]
+                if k not in prev_ps:
+                    has_error += 1
+                    LOG.e(f"prev param <{k}> not found.")
+                    continue
+                b = prev_ps[k]
+                if a.shape != b.shape:
+                    has_error += 1
+                    LOG.e(f"Params <{k}> shape not match {a.shape} != {b.shape}")
+                    continue
+                std_a, mean_a = a.std(), a.mean()
+                std_b, mean_b = b.std(), b.mean()
+                n = a.size
+                # law of large numbers: the standard errors of the sample
+                # mean and sample std both shrink like 1/sqrt(n)
+                std_mean_a = (std_a+std_b)/2 / np.sqrt(n) + 1e-6
+                std_std_a = (std_a+std_b)/2 / np.sqrt((n-1)/2) + 1e-6
+                x = 4
+                if np.abs(mean_a - mean_b) > x * std_mean_a:
+                    has_error += 1
+                    LOG.e(f"param mean not match, mean_a:{mean_a}, mean_b:{mean_b}, acceptable range:({mean_a - x * std_mean_a}, {mean_a + x * std_mean_a}) name:{k} shape:{a.shape}")
+                elif np.abs(std_a - std_b) > x * std_std_a:
+                    has_error += 1
+                    LOG.e(f"param std not match, std_a:{std_a}, std_b:{std_b}, acceptable range:({std_a - x * std_std_a}, {std_a + x * std_std_a}) name:{k} shape:{a.shape}")
+                else:
+                    LOG.i(f"check param ok: <{k}> shape:{a.shape}")
+                var = pps[k]
+                if hasattr(var, "copy_"):
+                    import torch
+                    var.data.copy_(torch.from_numpy(b))
+                else:
+                    var.assign(b)
+        else:
+            with open(fpath, 'wb') as f:
+                pickle.dump(ps, f)
+            LOG.i(f"save params ok")
+
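The parameter comparison is statistical rather than elementwise, since two freshly initialized networks cannot match value-for-value: for a parameter with n = a.size entries, the standard error of the sample mean is about std/sqrt(n) and that of the sample std about std/sqrt(2(n-1)), so two draws from the same initialization distribution should agree within x = 4 such deviations (the code's sqrt((n-1)/2) divisor is the same quantity up to a factor of 2, i.e. a slightly looser bound). Whatever the verdict, each checked parameter is then overwritten with the recorded value via copy_ (torch) or assign (jittor), so subsequent forward passes of both frameworks start from identical weights.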
+
+    def hook_module(self, mod):
+        def forward_hook(self2, input, output):
+            if "relu" not in self2.__class__.__name__.lower():
+                # skip relu inputs: an in-place relu overwrites its input
+                # buffer, so the recorded input would not be comparable
+                self.record(self2.__ad_mod_name__+".input", input)
+            self.record(self2.__ad_mod_name__+".output", output)
+
+        names = []
+        for name, module in mod.named_modules():
+            module.__ad_mod_name__ = name
+            names.append(name)
+            module.register_forward_hook(forward_hook)
+        self.record_params(mod.state_dict())
+        self.record("module names", names)
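hook_module() tags every submodule with its dotted name from named_modules(), installs one shared forward_hook, and records the parameter statistics plus the module-name list up front, so structural mismatches between the two networks surface before any activation is compared. Since register_forward_hook() keeps a single __fhook__ per instance, hooking a module twice simply replaces the earlier hook.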