add AutoDiff

Dun Liang 2020-09-19 13:44:55 +08:00
parent ee4dbe4c11
commit e93810879c
3 changed files with 293 additions and 1 deletion

@@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.7.16'
__version__ = '1.1.7.18'
from . import lock
with lock.lock_scope():
from . import compiler
@@ -479,6 +479,17 @@ class Module:
self.dfs([], None, callback, callback_leave)
return _uniq(ps)
def named_parameters(self):
ps = self.parameters()
return [ (p.name(), p) for p in ps ]
def state_dict(self):
ps = self.parameters()
return { p.name(): p for p in ps }
def load_state_dict(self, params):
self.load_parameters(params)
def modules(self):
ms = []
def callback(parents, k, v, n):
@@ -487,6 +498,36 @@ class Module:
self.dfs([], None, callback, None)
return _uniq(ms)
def named_modules(self):
ms = []
stack = []
def callback(parents, k, v, n):
stack.append(str(k))
name = ".".join(stack[1:])
ms.append((name, v))
def callback_leave(parents, k, v, n):
stack.pop()
self.dfs([], "", callback, callback_leave)
return ms
def register_forward_hook(self, func):
cls = self.__class__
self.__fhook__ = func
if hasattr(cls, "__hooked__"):
return
cls.__hooked__ = True
origin_call = cls.__call__
def new_call(self, *args, **kw):
ret = origin_call(self, *args, **kw)
if hasattr(self, "__fhook__"):
if len(kw):
self.__fhook__(self, args, ret, kw)
else:
self.__fhook__(self, args, ret)
return ret
self.__class__.__call__ = new_call
def children(self):
cd = []
def callback(parents, k, v, n):
@@ -720,6 +761,7 @@ def make_module(func, exec_n_args=1):
return f"{func.__name__}({self.extra_repr()})"
def extra_repr(self):
return ",".join(map(str, self.args))
MakeModule.__name__ = func.__name__
return MakeModule
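
For orientation, a minimal usage sketch of the Module helpers added in this file (named_parameters, state_dict, named_modules, register_forward_hook). The two-layer Net and the fc1_hook below are hypothetical and only for illustration; the exact parameter names (e.g. "fc1.weight") depend on how parameters() assigns Var names.

import jittor as jt
from jittor import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def execute(self, x):
        return self.fc2(self.fc1(x))

net = Net()

# named_parameters()/state_dict() key parameters by their dotted path,
# e.g. "fc1.weight", "fc1.bias", ...
for name, p in net.named_parameters():
    print(name, p.shape)
print(sorted(net.state_dict().keys()))

# named_modules() yields (dotted name, module); the root module gets ""
for name, m in net.named_modules():
    print(repr(name), type(m).__name__)

# register_forward_hook(func) wraps the class __call__ once and afterwards
# calls func(module, args, output) for instances that stored a hook
def fc1_hook(mod, args, output):
    print("fc1 output shape:", output.shape)

net.fc1.register_forward_hook(fc1_hook)
y = net(jt.random((3, 4)))

Because the hook is stored per instance in __fhook__ while __call__ is wrapped once per class, only modules that actually registered a hook receive callbacks.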

@@ -0,0 +1,68 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Authors: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import numpy as np
import os
import sys
import jittor as jt
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torchvision.models as tcmodels
from torch import nn
except:
torch = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "skip_this_test")
class TestAutoDiff(unittest.TestCase):
def test_pt_hook(self):
code = '''
import numpy as np
from jittor_utils import auto_diff
import torch
import torchvision.models as tcmodels
net = tcmodels.resnet50()
net.train()
hook = auto_diff.Hook("resnet50")
hook.hook_module(net)
np.random.seed(0)
data = np.random.random((2,3,224,224)).astype('float32')
data = torch.Tensor(data)
net(data)
# assert auto_diff.has_error == 0, auto_diff.has_error
'''
with open("/tmp/test_pt_hook.py", 'w') as f:
f.write(code)
assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0
code = '''
import numpy as np
import jittor as jt
from jittor_utils import auto_diff
from jittor.models import resnet50
net = resnet50()
net.train()
hook = auto_diff.Hook("resnet50")
hook.hook_module(net)
np.random.seed(0)
data = np.random.random((2,3,224,224)).astype('float32')
data = jt.array(data)
net(data)
# assert auto_diff.has_error == 0, auto_diff.has_error
'''
with open("/tmp/test_jt_hook.py", 'w') as f:
f.write(code)
assert os.system(sys.executable+" /tmp/test_jt_hook.py") == 0
if __name__ == "__main__":
unittest.main()
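
The PyTorch script above is run twice on purpose: the first invocation populates the auto_diff cache with reference pickles, the second replays them against identical data, and the Jittor script is then checked against the same cache. When adapting the test, the cache has to be cleared to re-record; a small sketch of doing that by hand, following the path layout used by Hook.__init__ in auto_diff.py below:

# Remove the recorded reference for one experiment so the next run records
# fresh values instead of comparing against stale ones.
import shutil
from pathlib import Path

base_name = "resnet50"   # the name passed to auto_diff.Hook(...)
cache_dir = Path.home() / ".cache" / "jittor" / "auto_diff" / base_name
shutil.rmtree(cache_dir, ignore_errors=True)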

@@ -0,0 +1,182 @@
import os
from pathlib import Path
import pickle
import numpy as np
import jittor_utils
from jittor_utils import LOG
jittor_utils.try_import_jit_utils_core()
has_error = 0
def convert(data):
if isinstance(data, tuple):
return tuple( convert(v) for v in data )
if isinstance(data, list):
return [ convert(v) for v in data ]
if isinstance(data, np.ndarray):
return data
if isinstance(data, dict):
return {k:convert(data[k]) for k in data}
if hasattr(data, "numpy"):
return data.detach().numpy()
return data
class Hook:
def __init__(self, base_name, rtol=5e-2, atol=1e-3):
self.rid = 0
self.base_name = base_name
self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
os.makedirs(self.base_path, exist_ok=True)
self.rtol = rtol
self.atol = atol
LOG.i("Use cache path:", self.base_path)
LOG.i(f"rtol:{rtol} atol:{atol}")
def check_array(self, name, a, b):
rtol = self.rtol
atol = self.atol
global has_error
err = np.abs(a-b)
tol = atol + rtol * np.abs(b)
is_error = np.logical_or( err > tol, (a>=-1e-5)!=(b>=-1e-5))
index = np.where(is_error)
assert len(index)>0
if len(index[0]) == 0:
return
has_error += 1
LOG.e(f"Ndarray <{name}> not match, shape:{a.shape}")
i = tuple( i[0] for i in index )
err_rate = is_error.mean()
LOG.e(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}%")
if err_rate > 0.01:
LOG.e("!"*10+"Very HIGH err rate"+"!"*10)
def check(self, name, pre_data, data):
global has_error
assert type(pre_data) == type(data)
if isinstance(pre_data, (list, tuple)):
if len(pre_data) != len(data):
has_error += 1
LOG.e(f"Name <{name}> len not match, {len(pre_data)} != {len(data)}")
n = max(len(pre_data), len(data))
for i in range(n):
a = pre_data[i] if i<len(pre_data) else "None"
b = data[i] if i<len(data) else "None"
self.check(name+f".{i}", a, b)
elif isinstance(pre_data, np.ndarray):
if pre_data.shape != data.shape:
has_error += 1
LOG.e(f"Ndarray shape <{name}> not match")
return
self.check_array(name, pre_data, data)
elif isinstance(pre_data, dict):
if len(pre_data) != len(data):
has_error += 1
LOG.e(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
for k in pre_data:
if k not in data:
has_error += 1
LOG.e(f"Key <{k}> not in data, Name <{name}>")
continue
self.check(name+f".{i}", pre_data[k], data[k])
else:
if pre_data != data:
has_error += 1
LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}")
def record(self, name, data):
rid = self.rid
self.rid += 1
fpath = os.path.join(self.base_path, f"{rid}.pkl")
data = convert(data)
if os.path.isfile(fpath):
with open(fpath, 'rb') as f:
pre_name, pre_data = pickle.load(f)
if pre_name != name:
global has_error
has_error += 1
LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
return
LOG.i(f"check {rid}:<{name}> ...")
self.check(name, pre_data, data)
else:
with open(fpath, 'wb') as f:
pickle.dump((name, data), f)
LOG.i(f"save {rid}:<{name}> ok")
def record_params(self, parameters_dict):
global has_error
pps = {}
for k, v in parameters_dict.items():
if k.endswith("num_batches_tracked"):
continue
pps[k] = v
ps = { name:convert(param) for name, param in pps.items() }
fpath = os.path.join(self.base_path, f"params.pkl")
if os.path.isfile(fpath):
with open(fpath, 'rb') as f:
prev_ps = pickle.load(f)
if len(prev_ps) != len(ps):
has_error += 1
LOG.e(f"Params len not match {len(prev_ps)} != {len(ps)}")
for k in ps:
a = ps[k]
if k not in prev_ps:
has_error += 1
LOG.e(f"prev param <{k}> not found.")
continue
b = prev_ps[k]
if a.shape != b.shape:
has_error += 1
LOG.e(f"Params <{k}> shape not match {a.shape} != {b.shape}")
continue
std_a, mean_a = a.std(), a.mean()
std_b, mean_b = b.std(), b.mean()
n = a.size
# law of large numbers: the sample mean (and roughly the sample std) of n
# values fluctuates on the order of sigma/sqrt(n), hence the bands below
std_mean_a = (std_a+std_b)/2 / np.sqrt(n) + 1e-6
std_std_a = (std_a+std_b)/2 / np.sqrt((n-1)/2) + 1e-6
x = 4
if np.abs(mean_a - mean_b) > x * std_mean_a:
has_error += 1
LOG.e(f"param mean not match, mean_a:{mean_a}, mean_b:{mean_b}, acceptable range:({mean_a - x * std_mean_a}, {mean_a + x * std_mean_a}) name:{k} shape:{a.shape}")
elif np.abs(std_a - std_b) > x * std_std_a:
has_error += 1
LOG.e(f"param std not match, std_a:{std_a}, std_b:{std_b}, acceptable range:({std_a - x * std_std_a}, {std_a + x * std_std_a}) name:{k} shape:{a.shape}")
else:
LOG.i(f"check param ok: <{k}> shape:{a.shape}")
var = pps[k]
if hasattr(var, "copy_"):
import torch
var.data.copy_(torch.from_numpy(b))
else:
var.assign(b)
else:
with open(fpath, 'wb') as f:
pickle.dump(ps, f)
LOG.i(f"save params ok")
def hook_module(self, mod):
def forward_hook(self2, input, output):
if "relu" not in self2.__class__.__name__.lower():
# skip checking relu inputs, because they may have been modified in-place
self.record(self2.__ad_mod_name__+".input", input)
self.record(self2.__ad_mod_name__+".output", output)
names = []
for name, module in mod.named_modules():
module.__ad_mod_name__ = name
names.append(name)
module.register_forward_hook(forward_hook)
self.record_params(mod.state_dict())
self.record("module names", names)