mirror of https://github.com/Jittor/Jittor

auto diff hook optimizer and rand

parent c07d85dade
commit d6b9e32428
@@ -502,9 +502,10 @@ class Module:
         ms = []
         stack = []
         def callback(parents, k, v, n):
-            stack.append(str(k))
-            name = ".".join(stack[1:])
-            ms.append((name, v))
+            if isinstance(v, Module):
+                stack.append(str(k))
+                name = ".".join(stack[1:])
+                ms.append((name, v))
         def callback_leave(parents, k, v, n):
             stack.pop()
         self.dfs([], "", callback, callback_leave)
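The hunk above builds (name, module) pairs via dfs and now only collects sub-Modules. A quick sketch of iterating such a listing, assuming it backs the usual named_modules-style accessor (the accessor name is an assumption here):

    import jittor as jt
    from jittor import nn

    class Net(nn.Module):
        def __init__(self):
            self.fc = nn.Linear(4, 4)
            self.act = nn.ReLU()
        def execute(self, x):
            return self.act(self.fc(x))

    net = Net()
    # With this change, only sub-Modules are collected, named by their dotted path.
    for name, m in net.named_modules():
        print(name, type(m).__name__)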
@@ -815,6 +816,28 @@ Var.__str__ = lambda x: str(x.data)
 Var.__repr__ = lambda x: str(x.data)
 Var.peek = lambda x: f"{x.dtype}{x.shape}"

+def item(v):
+    return v.data.item()
+
+def to_int(v):
+    dtype = str(v.dtype)
+    assert dtype.startswith("int")
+    return v.item()
+
+def to_float(v):
+    dtype = str(v.dtype)
+    assert dtype.startswith("float")
+    return v.item()
+
+def to_bool(v):
+    dtype = str(v.dtype)
+    assert dtype.startswith("int") or dtype=="bool"
+    return bool(v.item())
+
+Var.item = item
+Var.__int__ = to_int
+Var.__float__ = to_float
+Var.__bool__ = to_bool
+
 ori_int = int

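A minimal sketch of the scalar conversions registered above (illustrative; assumes a build containing this change):

    import jittor as jt

    v = jt.array([3])        # int Var holding a single element
    print(v.item())          # 3, via v.data.item()
    print(int(v))            # Var.__int__ asserts an int dtype first
    print(bool(v))           # Var.__bool__ accepts int/bool dtypes only
    f = jt.array([1.5])
    print(float(f))          # Var.__float__ asserts a float dtype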
@@ -89,6 +89,11 @@ def chunk(x, chunks, dim=0):
     return res
 jt.Var.chunk = chunk

+def expand(x, shape):
+    return x.broadcast(shape)
+jt.Var.expand = expand
+
 def stack(x, dim=0):
     r'''
     Concatenates sequence of vars along a new dimension.
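A tiny sketch of the newly registered jt.Var.expand (illustrative; it simply forwards to broadcast):

    import jittor as jt

    x = jt.array([[1.0], [2.0]])   # shape (2, 1)
    y = x.expand((2, 3))           # broadcast to shape (2, 3)
    print(y.shape)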
@@ -261,19 +261,20 @@ class Linear(Module):
         return x

 class BatchNorm(Module):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
-        assert affine == None
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
         self.sync = sync
         self.num_features = num_features
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
         self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
         self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

+        self.affine = affine
+        if self.affine:
+            self.weight = init.constant((num_features,), "float32", 1.0)
+            self.bias = init.constant((num_features,), "float32", 0.0)
+
     def execute(self, x):
         if self.is_train:
             xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
@@ -292,23 +293,27 @@ class BatchNorm(Module):
             running_mean = self.running_mean.broadcast(x, [0,2,3])
             running_var = self.running_var.broadcast(x, [0,2,3])
             norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
+        if not self.affine:
+            return norm_x
         w = self.weight.broadcast(x, [0,2,3])
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b

 class BatchNorm1d(Module):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
-        assert affine == None
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
         self.sync = sync
         self.num_features = num_features
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
         self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
         self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

+        self.affine = affine
+        if self.affine:
+            self.weight = init.constant((num_features,), "float32", 1.0)
+            self.bias = init.constant((num_features,), "float32", 0.0)
+
     def execute(self, x):
         if self.is_train:
             xmean = jt.mean(x, dims=[0], keepdims=1)
@@ -328,20 +333,24 @@ class BatchNorm1d(Module):
             running_mean = self.running_mean.broadcast(x, [0])
             running_var = self.running_var.broadcast(x, [0])
             norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
+        if not self.affine:
+            return norm_x
         w = self.weight.broadcast(x, [0])
         b = self.bias.broadcast(x, [0])
         return norm_x * w + b

 class InstanceNorm2d(Module):
-    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True):
-        assert affine == None
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train=True, sync=True):
         self.sync = sync
         self.num_features = num_features
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
+
+        self.affine = affine
+        if self.affine:
+            self.weight = init.constant((num_features,), "float32", 1.0)
+            self.bias = init.constant((num_features,), "float32", 0.0)

     def execute(self, x):
         xmean = jt.mean(x, dims=[2,3], keepdims=1)
@@ -352,18 +361,22 @@ class InstanceNorm2d(Module):

         xvar = jt.maximum(x2mean-xmean*xmean, 0)
         norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+        if not self.affine:
+            return norm_x
         w = self.weight.broadcast(x, [0,2,3])
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b

 class GroupNorm(Module):
-    def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True):
-        assert affine == None
+    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True):
         self.num_groups = num_groups
         self.num_channels = num_channels
         self.eps = eps
-        self.weight = init.constant((num_channels,), "float32", 1.0)
-        self.bias = init.constant((num_channels,), "float32", 0.0)
+
+        self.affine = affine
+        if self.affine:
+            self.weight = init.constant((num_channels,), "float32", 1.0)
+            self.bias = init.constant((num_channels,), "float32", 0.0)

     def execute(self, x):
         N,C,H,W = x.shape
@@ -374,6 +387,8 @@ class GroupNorm(Module):
         x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
         xvar = jt.maximum(x2mean-xmean*xmean, 0)
         norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+        if not self.affine:
+            return norm_x
         w = self.weight.reshape((1,self.num_groups,C//self.num_groups,1))
         b = self.bias.reshape((1,self.num_groups,C//self.num_groups,1))
         return (norm_x * w + b).reshape((N,C,H,W))
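A short sketch of the affine switch that the normalization layers gain in these hunks (illustrative; InstanceNorm2d and GroupNorm accept the same flag):

    import jittor as jt
    from jittor import nn

    x = jt.random((8, 16, 32, 32))
    bn = nn.BatchNorm(16)                      # affine=True by default: learns weight/bias
    y = bn(x)
    bn_plain = nn.BatchNorm(16, affine=False)  # no weight/bias, returns norm_x directly
    y_plain = bn_plain(x)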
@@ -493,7 +508,11 @@ class ConvTranspose(Module):

         self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
         if bias:
-            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
+            fan=1
+            for i in self.weight.shape[1:]:
+                fan *= i
+            bound = 1 / math.sqrt(fan)
+            self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
         else:
             self.bias = None

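A quick arithmetic sketch of the fan-in bound now used for the ConvTranspose bias (assumed example values; the weight shape is (in_channels, out_channels) + kernel_size, and fan multiplies everything after dim 0):

    import math

    out_channels, kernel_size = 16, (3, 3)
    fan = out_channels * kernel_size[0] * kernel_size[1]   # 16*3*3 = 144
    bound = 1 / math.sqrt(fan)                             # ~0.0833
    # bias ~ Uniform(-bound, bound) instead of the old Uniform(-1, 1)
    print(fan, bound)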
@@ -33,6 +33,11 @@ class Optimizer(object):
             assert isinstance(pg, dict)
             self.param_groups.append(pg)
         self.n_step = 0

+    @property
+    def defaults(self):
+        exclude = set(("defaults", "param_groups", "n_step"))
+        return { k:v for k, v in self.__dict__.items() if k[0] != '_' and k not in exclude }
+
     def pre_step(self, loss):
         """ something should be done before step, such as calc gradients, mpi sync, and so on.
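A small sketch of the new Optimizer.defaults property (illustrative; with this commit applied, it gathers the optimizer's public attributes so they can be recorded by the auto-diff hooks):

    import jittor as jt
    from jittor import nn, optim

    model = nn.Linear(4, 2)
    opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # everything in opt.__dict__ except _private keys, "defaults", "param_groups", "n_step"
    print(opt.defaults)   # e.g. {'lr': 0.1, 'momentum': 0.9, ...}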
@@ -76,8 +81,12 @@ class Optimizer(object):
                 pg_grads[i] = grads[pid].stop_grad()
                 pid += 1

-    def step(self, loss):
+    def backward(self, loss):
         self.pre_step(loss)
+
+    def step(self, loss=None):
+        if loss is not None:
+            self.pre_step(loss)
         for pg in self.param_groups:
             lr = pg.get("lr", self.lr)
             for p, g in zip(pg["params"], pg["grads"]):
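A minimal training-step sketch using the split backward/step API introduced above (illustrative; the model and data are placeholders):

    import jittor as jt
    from jittor import nn, optim

    model = nn.Linear(10, 1)
    opt = optim.SGD(model.parameters(), lr=0.01)
    x, y = jt.random((4, 10)), jt.random((4, 1))

    loss = ((model(x) - y) ** 2).mean()
    opt.step(loss)            # old style still works: pre_step(loss), then the update

    loss = ((model(x) - y) ** 2).mean()
    opt.backward(loss)        # new: compute and stash gradients only
    opt.step()                # new: loss=None skips pre_step and applies the update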
@@ -106,8 +115,9 @@ class SGD(Optimizer):
             for p in pg["params"]:
                 values.append(jt.zeros(p.shape, p.dtype).stop_grad())

-    def step(self, loss):
-        self.pre_step(loss)
+    def step(self, loss=None):
+        if loss is not None:
+            self.pre_step(loss)
         for pg in self.param_groups:
             # get arguments from each param_groups
             lr = pg.get("lr", self.lr)
@@ -149,8 +159,9 @@ class RMSprop(Optimizer):
             for p in pg["params"]:
                 values.append(jt.zeros(p.shape, p.dtype).stop_grad())

-    def step(self, loss):
-        self.pre_step(loss)
+    def step(self, loss=None):
+        if loss is not None:
+            self.pre_step(loss)
         for pg in self.param_groups:
             # get arguments from each param_groups
             lr = pg.get("lr", self.lr)
@@ -184,8 +195,9 @@ class Adam(Optimizer):
                 values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                 m.append(jt.zeros(p.shape, p.dtype).stop_grad())

-    def step(self, loss):
-        self.pre_step(loss)
+    def step(self, loss=None):
+        if loss is not None:
+            self.pre_step(loss)
         n = float(self.n_step)
         for pg in self.param_groups:
             # get arguments from each param_groups
@@ -21,7 +21,8 @@ class Linear(Module):
         self.b = jt.random((out_features,))-0.5 if bias else None
     def execute(self, x):
         x = matmul(x, self.w)
-        if self.b: return x+self.b
+        if self.b is not None:
+            return x+self.b
         return x

 def relu(x):
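This one-line fix interacts with the Var.__bool__ added earlier in this commit: truth-testing a float bias Var would now trip the int/bool dtype assert, so an identity check is used instead. A quick sketch of the distinction (illustrative):

    import jittor as jt

    b = jt.random((3,)) - 0.5   # float32 Var
    print(b is not None)        # identity check: always safe
    # bool(b) would call Var.__bool__, which asserts an int/bool dtype and would fail here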
@@ -140,10 +140,10 @@ class TestMklConvOp(unittest.TestCase):
         dw_jt_tune=gs_tune[1].data
         logs = find_log_with_re(rawlogs,
             "Run tuner conv: confidence\\((20)\\) candidates\\((.*)\\)$")
-        assert len(logs) == 1
+        assert len(logs) == 2, len(logs)
         assert logs[0][0] == "20", "confidence of reorder should be 20"
         candidates = simple_parser(logs[0][1])
-        assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates
+        assert candidates == {"relay0":[1,0]}, candidates

         logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)")
         assert len(logs)==2
@@ -186,10 +186,11 @@ class TestMklConvOp(unittest.TestCase):
         dw_jt_tune=gs_tune[1].data
         logs = find_log_with_re(rawlogs,
             "Run tuner conv: confidence\\((20)\\) candidates\\((.*)\\)$")
-        assert len(logs) == 1
+        assert len(logs) == 2
         assert logs[0][0] == "20", "confidence of reorder should be 20"
         candidates = simple_parser(logs[0][1])
-        assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates
+        assert candidates == {"relay0":[1,0]}, candidates
+        # assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates

         logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)")
         assert len(logs)==2
@@ -26,10 +26,10 @@ class FakeMpiBatchNorm(nn.Module):
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
         self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
         self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
+        self.weight = init.constant((num_features,), "float32", 1.0)
+        self.bias = init.constant((num_features,), "float32", 0.0)

     def execute(self, x, global_x):
         if self.is_train:
@@ -91,6 +91,7 @@ class TestMpiBatchnorm(unittest.TestCase):
         gs2 = jt.grad(y2,bn2.parameters())

         assert np.allclose(y1.data, y2.data, atol=1e-5),(mpi.world_rank(),y1.data, y2.data, y1.data-y2.data)
+        assert len(gs1) == len(gs2)
         for i in range(len(gs1)):
             assert np.allclose(gs1[i].data, gs2[i].data, rtol=1e-2),(mpi.world_rank(),gs1[i].data, gs2[i].data,gs1[i].data-gs2[i].data)

@@ -24,7 +24,7 @@ class TestNanoString(unittest.TestCase):
         # t is about 0.01 for 100w loop
         # 92ns one loop
         print("nanostring time", t)
-        assert t < [1.5e-7, 1.7e-7][mid], t
+        assert t < [1.5e-7, 1.9e-7][mid], t

         assert (jt.hash("asdasd") == 4152566416)
         assert str(jt.NanoString("float"))=="float32"
@@ -4,6 +4,7 @@ import pickle
 import numpy as np
 import jittor_utils
 from jittor_utils import LOG
+import sys

 jittor_utils.try_import_jit_utils_core()

@@ -20,11 +21,58 @@ def convert(data):
     if isinstance(data, dict):
         return {k:convert(data[k]) for k in data}
     if hasattr(data, "numpy"):
-        return data.detach().numpy()
+        if "Var" in data.__class__.__name__:
+            return data.numpy()
+        else:
+            return data.detach().cpu().numpy()
     return data

+rand_hooked = False
+
+def hook_pt_rand(*shape):
+    import torch
+    if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], torch.Size):
+        shape = tuple(shape[0])
+    np.random.seed(0)
+    return torch.from_numpy(np.random.rand(*tuple(shape)).astype("float32"))
+
+def hook_pt_normal(mean, std):
+    import torch
+    shape = tuple(mean.shape)
+    np.random.seed(0)
+    return torch.from_numpy(np.random.normal(size=shape).astype("float32")).to(std.device) * std + mean
+
+def hook_jt_rand(shape, dtype="float32", rtype="uniform"):
+    import jittor
+    np.random.seed(0)
+    if rtype == "normal":
+        return jittor.array(np.random.normal(size=shape).astype(str(dtype)))
+    return jittor.array(np.random.rand(*shape).astype(str(dtype)))
+
+def hook_rand():
+    global rand_hooked
+    if rand_hooked: return
+    rand_hooked = True
+    np.random.seed(0)
+    if "torch" in sys.modules:
+        LOG.i("Hook torch.rand")
+        torch = sys.modules["torch"]
+        torch.rand = hook_pt_rand
+        torch.normal = hook_pt_normal
+        torch.manual_seed(0)
+    if "jittor" in sys.modules:
+        jittor = sys.modules["jittor"]
+        LOG.i("Hook jittor.random")
+        jittor.random = hook_jt_rand
+        jittor.seed(0)
+
+
 class Hook:
     def __init__(self, base_name, rtol=5e-2, atol=1e-3):
         if os.environ.get("use_auto_diff", '1') == '0':
             return
+        hook_rand()
         self.rid = 0
         self.base_name = base_name
         self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
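A rough usage sketch of hook_rand (the auto_diff import path is an assumption; adjust it to wherever this module lives in your tree). Because every hooked call reseeds np.random with 0, same-shaped draws match across frameworks:

    import numpy as np
    import torch
    import jittor as jt
    import auto_diff                  # assumed module name for the file in this diff

    auto_diff.hook_rand()
    a = torch.rand(2, 3).numpy()      # served by hook_pt_rand
    b = jt.random((2, 3)).numpy()     # served by hook_jt_rand
    assert np.allclose(a, b)          # both are np.random.rand(2, 3) after seed(0)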
@@ -47,16 +95,19 @@ class Hook:
             return

         has_error += 1
-        LOG.e(f"Ndarray <{name}> not match, shape:{a.shape}")
+        LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
         i = tuple( i[0] for i in index )
         err_rate = is_error.mean()
-        LOG.e(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}%")
+        LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) astd({a.std()}) bstd({b.std()}) ")
         if err_rate > 0.01:
             LOG.e("!"*10+"Very HIGH err rate"+"!"*10)

     def check(self, name, pre_data, data):
         global has_error
-        assert type(pre_data) == type(data)
+        if type(pre_data) != type(data):
+            LOG.e(f"type not match, {pre_data.__class__.__name__}!={data.__class__.__name__}, name: {name}")
+            has_error += 1
+            return
         if isinstance(pre_data, (list, tuple)):
             if len(pre_data) != len(data):
                 has_error += 1
@@ -75,19 +126,26 @@ class Hook:
         elif isinstance(pre_data, dict):
             if len(pre_data) != len(data):
                 has_error += 1
-                LOG.e(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
+                LOG.w(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
             for k in pre_data:
+                pv = pre_data[k]
                 if k not in data:
                     has_error += 1
-                    LOG.e(f"Key <{k}> not in data, Name <{name}>")
+                    msg = f"Key <{k}> not in data, Name <{name}>"
+                    if isinstance(pv, np.ndarray):
+                        LOG.e(msg)
+                    else:
+                        LOG.w(msg)
                     continue
-                self.check(name+f".{i}", pre_data[k], data[k])
+                self.check(name+f".{k}", pre_data[k], data[k])
         else:
             if pre_data != data:
                 has_error += 1
                 LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}")

-    def record(self, name, data):
+    def record(self, name, data, ex_name=""):
         if os.environ.get("use_auto_diff", '1') == '0':
             return
         rid = self.rid
         self.rid += 1
         fpath = os.path.join(self.base_path, f"{rid}.pkl")
@@ -99,15 +157,18 @@ class Hook:
                 global has_error
                 has_error += 1
                 LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
+                self.rid -= 1
                 return
-            LOG.i(f"check {rid}:<{name}> ...")
-            self.check(name, pre_data, data)
+            LOG.i(f"check {rid}:<{ex_name}{name}> ...")
+            self.check(ex_name+name, pre_data, data)
         else:
             with open(fpath, 'wb') as f:
                 pickle.dump((name, data), f)
             LOG.i(f"save {rid}:<{name}> ok")

     def record_params(self, parameters_dict):
         if os.environ.get("use_auto_diff", '1') == '0':
             return
         rid = self.rid
         self.rid += 1
         global has_error
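A sketch of the record flow shown in this hunk (hedged: the save-versus-check decision appears to be keyed on whether the pickle for this rid already exists, which is implied but not shown here; the names below are placeholders):

    hook = Hook("mynet")
    hook.record("logits", logits_array)        # reference run: pickles (name, data) as 0.pkl
    # ... later, in the run being compared, with the same base_name and call order ...
    hook.record("logits", other_logits_array)  # loads 0.pkl and checks values via self.check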
@@ -162,23 +223,75 @@ class Hook:
         LOG.i(f"save params ok")

-    def hook_module(self, mod):
+    def hook_module(self, mod, mod_name=""):
         if os.environ.get("use_auto_diff", '1') == '0':
             return
+        if mod_name != "":
+            mod_name = "<" + mod_name + ">"
         def forward_hook(self2, input, output):
+            ex_name = '[' + self2.__class__.__name__ + ']'
             if "relu" not in self2.__class__.__name__.lower():
                 # not test relu, because input may be inplaced
-                self.record(self2.__ad_mod_name__+".input", input)
-                self.record(self2.__ad_mod_name__+".output", output)
+                self.record(self2.__ad_mod_name__+".input", input, ex_name)
+                self.record(self2.__ad_mod_name__+".output", output, ex_name)

         names = []
         for name, module in mod.named_modules():
+            name = mod_name + name
             module.__ad_mod_name__ = name
             names.append(name)
             module.register_forward_hook(forward_hook)
-        self.record_params(mod.state_dict())
+            mod_class_name = module.__class__.__name__.lower()
+            # make dropout in eval mod and record dropout.p
+            if "dropout" in mod_class_name:
+                self.record(name+'.p', module.p, "["+mod_class_name+"]")
+                module.eval()
+        ps = { mod_name+k:v for k, v in mod.state_dict().items() }
+        self.record_params(ps)
         self.record("module names", names)

+    def hook_optimizer(self, opt, opt_name=""):
+        '''
+        net = Model()
+        opt = optim.SGD(net.parameters(), 0.1)
+        hook.hook_optimizer(opt)
+        '''
+        if os.environ.get("use_auto_diff", '1') == '0':
+            return
+        origin_step = opt.step
+        ex_name = '['+opt.__class__.__name__+']'
+        def step_hook(*args, **kw):
+            origin_step(*args, **kw)
+            self.record(opt_name+".default", opt.defaults, ex_name)
+            gid = 0
+            n_params = 0
+            for pg in opt.param_groups:
+                for p in pg["params"]:
+                    if hasattr(p, "is_stop_grad"):
+                        if p.is_stop_grad():
+                            continue
+                        n_params += 1
+                    else:
+                        n_params += 1
+
+            self.record(opt_name+".n_params", n_params, ex_name)
+
+            for pg in opt.param_groups:
+                for i, p in reversed(list(enumerate(pg["params"]))):
+                    if hasattr(p, "is_stop_grad"):
+                        if p.is_stop_grad():
+                            continue
+                        self.record(f"{opt_name}.grads[{gid}]", pg["grads"][i], "["+p.name()+"]")
+                        self.record(f"{opt_name}.params[{gid}]", p, "["+p.name()+"]")
+                        gid += 1
+                    else:
+                        self.record(f"{opt_name}.grads[{gid}]", p.grad)
+                        self.record(f"{opt_name}.params[{gid}]", p)
+                        gid += 1
+        opt.step = step_hook
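A compact sketch tying the module and optimizer hooks together (illustrative; the Hook import path and the model are placeholders, and both runs being compared must use the same base_name and call order):

    import jittor as jt
    from jittor import nn, optim
    from auto_diff import Hook        # assumed import path for the module in this diff

    model = nn.Linear(8, 2)
    opt = optim.SGD(model.parameters(), lr=0.1)

    hook = Hook("linear_demo")        # records go to ~/.cache/jittor/auto_diff/linear_demo
    hook.hook_module(model, "net")    # forward inputs/outputs and params, prefixed <net>
    hook.hook_optimizer(opt, "opt")   # defaults, grads and params after each step

    loss = (model(jt.random((4, 8))) ** 2).mean()
    opt.step(loss)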
@@ -464,7 +464,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
             checkCudaErrors(cudaDeviceSynchronize());
         }));
     }
-    LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size();
+    LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size() << "device_sync:" << device_sync;
 #endif
 }

|
@ -41,6 +41,10 @@ void Op::forward(Var* input) {
|
|||
outputs_holder.emplace_back(input);
|
||||
}
|
||||
|
||||
VarPtr Op::duplicate() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
LOGw << "Grad of" << name() << "return zeros";
|
||||
return nullptr;
|
||||
|
|
src/op.h
@@ -47,6 +47,7 @@ struct Op : Node {
     virtual void do_prepare();
     virtual void do_run_after_prepare();
     virtual void do_run();
+    virtual VarPtr duplicate();
     void jit_run();

     string name_ex() const;
@@ -87,6 +87,18 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
     z = create_output(nullptr, binary_dtype_infer(op, x, y));
 }

+VarPtr dirty_clone_broadcast(Var* v) {
+    Op* op = v->input();
+    // dirty fix conv duplicated
+    if (op && !v->is_finished() && v->shape.size() > 4 && op->type() == OpType::broadcast) {
+        auto vp = op->duplicate();
+        if (vp) {
+            return vp;
+        }
+    }
+    return v;
+}
+
 VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     if (ns == ns_add) return dout;
     if (ns == ns_subtract) {
@@ -97,9 +109,9 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     }
     if (ns == ns_multiply) {
         if (v_index == 0)
-            return make_binary(y, dout, ns_multiply);
+            return make_binary(dirty_clone_broadcast(y), dirty_clone_broadcast(dout), ns_multiply);
         else
-            return make_binary(x, dout, ns_multiply);
+            return make_binary(dirty_clone_broadcast(x), dirty_clone_broadcast(dout), ns_multiply);
     }
     if (ns == ns_divide) {
         if (v_index == 0)
@@ -14,6 +14,10 @@ namespace jittor {
 #ifndef JIT
 static auto make_reduce = get_op_info("reduce")
     .get_constructor<VarPtr, Var*, NanoString, uint, uint>();
+static auto make_broadcast = get_op_info("broadcast_to")
+    .get_constructor<VarPtr, Var*, Var*, uint, uint>();
+static auto make_broadcast2 = get_op_info("broadcast_to")
+    .get_constructor<VarPtr, Var*, NanoVector, uint, uint>();

 BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
     auto count = dims.size();
@@ -51,6 +55,20 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask)
     this->keepdims_mask = keepdims_mask;
 }

+BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, uint dims_mask, uint keepdims_mask) : x(x), y(nullptr), shape(shape) {
+    auto count = __builtin_popcount(dims_mask);
+    if (!count) {
+        forward(x);
+        return;
+    }
+    flags.set(NodeFlags::_cpu);
+    flags.set(NodeFlags::_cuda);
+    set_type(OpType::broadcast);
+    z = create_output(NanoVector(), x->dtype());
+    bcast_mask = dims_mask;
+    this->keepdims_mask = keepdims_mask;
+}
+
 BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x), y(nullptr), shape(shape) {
     auto count = dims.size();
     // forward x if don't need broadcast
@@ -82,6 +100,13 @@ bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
     return false;
 }

+VarPtr BroadcastToOp::duplicate() {
+    if (y)
+        return make_broadcast(x, y, bcast_mask, keepdims_mask);
+    else
+        return make_broadcast2(x, shape, bcast_mask, keepdims_mask);
+}
+
 VarPtr BroadcastToOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     if (v_index==1) return nullptr;
     if (bcast_mask==0) return dout;
|
@ -21,12 +21,15 @@ struct BroadcastToOp : Op {
|
|||
BroadcastToOp(Var* x, Var* y, NanoVector dims=NanoVector());
|
||||
// @pybind(None)
|
||||
BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask);
|
||||
// @pybind(None)
|
||||
BroadcastToOp(Var* x, NanoVector shape, uint dims_mask, uint keepdims_mask);
|
||||
|
||||
bool need_broadcast(const Var* x, const NanoVector& shape);
|
||||
|
||||
const char* name() const override { return "broadcast_to"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
VarPtr duplicate() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -13,6 +13,8 @@ namespace jittor {
|
|||
#ifndef JIT
|
||||
static auto make_reindex_reduce = get_op_info("reindex_reduce")
|
||||
.get_constructor<VarPtr, Var*, NanoString, NanoVector, vector<string>&&, vector<string>&&, vector<Var*>&&>();
|
||||
static auto make_reindex = get_op_info("reindex")
|
||||
.get_constructor<VarPtr, Var*, NanoVector, vector<string>&&, float64, vector<string>&&, vector<Var*>&&>();
|
||||
|
||||
ReindexOp::ReindexOp(Var* x, NanoVector shape, vector<string>&& indexes, float64 overflow_value, vector<string>&& overflow_conditions, vector<Var*>&& extras)
|
||||
: x(x),
|
||||
|
@@ -63,6 +65,10 @@ ReindexOp::ReindexOp(Var* x, vector<Var*>&& indexes, float64 overflow_value, vec
     }
 }

+VarPtr ReindexOp::duplicate() {
+    return make_reindex(x, shape, clone(indexes), overflow_value, clone(overflow_conditions), clone(extras));
+}
+
 VarPtr ReindexOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     // Do not have grad to extras input
     if (v_index) return nullptr;
|
|
@ -98,6 +98,7 @@ struct ReindexOp : Op {
|
|||
const char* name() const override { return "reindex"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
VarPtr duplicate() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -50,19 +50,23 @@ void __unregiste_node_trace(Node* node) {
|
|||
|
||||
void __registe_node_trace_grad(Node* g, Node* node, int x_index) {
|
||||
if (!g) return;
|
||||
string& gname = trace_data.at(g);
|
||||
string name = "grad(";
|
||||
if (startswith(gname, "grad("))
|
||||
return;
|
||||
if (!node->is_var()) {
|
||||
name += node->op()->name_ex();
|
||||
name += ':';
|
||||
name += S(x_index);
|
||||
}
|
||||
name += ":" + gname;
|
||||
name += "):";
|
||||
name += trace_data.at(node);
|
||||
trace_data[g] = name;
|
||||
gname = name;
|
||||
std::function<void(Node*)> dfs = [&] (Node* node) {
|
||||
for (Node* i : node->inputs()) {
|
||||
string& iname = trace_data[i];
|
||||
if (startswith(iname, "__init__.py:grad:")) {
|
||||
if (iname.find("__init__.py:grad:") != string::npos && !startswith(iname, "grad(")) {
|
||||
iname = name;
|
||||
dfs(i);
|
||||
}
|
||||
|
|